diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml index 2596a18c..d21d0684 100644 --- a/.github/workflows/backend-ci.yml +++ b/.github/workflows/backend-ci.yml @@ -11,8 +11,8 @@ jobs: test: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 + - uses: actions/checkout@v6 + - uses: actions/setup-go@v6 with: go-version-file: backend/go.mod check-latest: false @@ -30,8 +30,8 @@ jobs: golangci-lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 + - uses: actions/checkout@v6 + - uses: actions/setup-go@v6 with: go-version-file: backend/go.mod check-latest: false @@ -43,5 +43,5 @@ jobs: uses: golangci/golangci-lint-action@v9 with: version: v2.7 - args: --timeout=5m - working-directory: backend + args: --timeout=30m + working-directory: backend \ No newline at end of file diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 50bb73e0..a1c6aa23 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -31,7 +31,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Update VERSION file run: | @@ -45,7 +45,7 @@ jobs: echo "Updated VERSION file to: $VERSION" - name: Upload VERSION artifact - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v7 with: name: version-file path: backend/cmd/server/VERSION @@ -55,7 +55,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup pnpm uses: pnpm/action-setup@v4 @@ -63,7 +63,7 @@ jobs: version: 9 - name: Setup Node.js - uses: actions/setup-node@v4 + uses: actions/setup-node@v6 with: node-version: '20' cache: 'pnpm' @@ -78,7 +78,7 @@ jobs: working-directory: frontend - name: Upload frontend artifact - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v7 with: name: frontend-dist path: backend/internal/web/dist/ @@ 
-89,25 +89,25 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 ref: ${{ github.event.inputs.tag || github.ref }} - name: Download VERSION artifact - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v8 with: name: version-file path: backend/cmd/server/ - name: Download frontend artifact - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v8 with: name: frontend-dist path: backend/internal/web/dist/ - name: Setup Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version-file: backend/go.mod check-latest: false @@ -173,7 +173,7 @@ jobs: run: echo "owner=$(echo '${{ github.repository_owner }}' | tr '[:upper:]' '[:lower:]')" >> $GITHUB_OUTPUT - name: Run GoReleaser - uses: goreleaser/goreleaser-action@v6 + uses: goreleaser/goreleaser-action@v7 with: version: '~> v2' args: release --clean --skip=validate ${{ env.SIMPLE_RELEASE == 'true' && '--config=.goreleaser.simple.yaml' || '' }} @@ -188,7 +188,7 @@ jobs: # Update DockerHub description - name: Update DockerHub description if: ${{ env.SIMPLE_RELEASE != 'true' && env.DOCKERHUB_USERNAME != '' }} - uses: peter-evans/dockerhub-description@v4 + uses: peter-evans/dockerhub-description@v5 env: DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} with: diff --git a/.github/workflows/security-scan.yml b/.github/workflows/security-scan.yml index 05dd1d1a..db922509 100644 --- a/.github/workflows/security-scan.yml +++ b/.github/workflows/security-scan.yml @@ -12,10 +12,11 @@ permissions: jobs: backend-security: runs-on: ubuntu-latest + timeout-minutes: 15 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Set up Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version-file: backend/go.mod check-latest: false @@ -28,22 +29,17 @@ jobs: run: | go install golang.org/x/vuln/cmd/govulncheck@latest govulncheck ./... 
- - name: Run gosec - working-directory: backend - run: | - go install github.com/securego/gosec/v2/cmd/gosec@latest - gosec -severity high -confidence high ./... frontend-security: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Set up pnpm uses: pnpm/action-setup@v4 with: version: 9 - name: Set up Node.js - uses: actions/setup-node@v4 + uses: actions/setup-node@v6 with: node-version: '20' cache: 'pnpm' diff --git a/.gitignore b/.gitignore index 48172982..297c1d6f 100644 --- a/.gitignore +++ b/.gitignore @@ -116,17 +116,20 @@ backend/.installed # =================== tests CLAUDE.md -AGENTS.md .claude scripts .code-review-state -openspec/ -docs/ +#openspec/ code-reviews/ -AGENTS.md +#AGENTS.md backend/cmd/server/server deploy/docker-compose.override.yml .gocache/ vite.config.js docs/* -.serena/ \ No newline at end of file +.serena/ +.codex/ +frontend/coverage/ +aicodex +output/ + diff --git a/DEV_GUIDE.md b/DEV_GUIDE.md index 541bf1fa..d0d362e0 100644 --- a/DEV_GUIDE.md +++ b/DEV_GUIDE.md @@ -209,7 +209,30 @@ git add ent/ # 生成的文件也要提交 --- -### 坑 10:PR 提交前检查清单 +### 坑 10:前端测试看似正常,但后端调用失败(模型映射被批量误改) + +**典型现象**: +- 前端按钮点测看起来正常; +- 实际通过 API/客户端调用时返回 `Service temporarily unavailable` 或提示无可用账号; +- 常见于 OpenAI 账号(例如 Codex 模型)在批量修改后突然不可用。 + +**根因**: +- OpenAI 账号编辑页默认不显式展示映射规则,容易让人误以为“没映射也没关系”; +- 但在**批量修改同时选中不同平台账号**(OpenAI + Antigravity/Gemini)时,模型白名单/映射可能被跨平台策略覆盖; +- 结果是 OpenAI 账号的关键模型映射丢失或被改坏,后端选不到可用账号。 + +**修复方案(按优先级)**: +1. **快速修复(推荐)**:在批量修改中补回正确的透传映射(例如 `gpt-5.3-codex -> gpt-5.3-codex-spark`)。 +2. 
**彻底重建**:删除并重新添加全部相关账号(最稳但成本高)。 + +**关键经验**: +- 如果某模型已被软件内置默认映射覆盖,通常不需要额外再加透传; +- 但当上游模型更新快于本仓库默认映射时,**手动批量添加透传映射**是最简单、最低风险的临时兜底方案; +- 批量操作前尽量按平台分组,不要混选不同平台账号。 + +--- + +### 坑 11:PR 提交前检查清单 提交 PR 前务必本地验证: diff --git a/Dockerfile b/Dockerfile index c9fcf301..1493e8a7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,7 +8,7 @@ ARG NODE_IMAGE=node:24-alpine ARG GOLANG_IMAGE=golang:1.25.7-alpine -ARG ALPINE_IMAGE=alpine:3.20 +ARG ALPINE_IMAGE=alpine:3.21 ARG GOPROXY=https://goproxy.cn,direct ARG GOSUMDB=sum.golang.google.cn @@ -36,7 +36,7 @@ RUN pnpm run build FROM ${GOLANG_IMAGE} AS backend-builder # Build arguments for version info (set by CI) -ARG VERSION=docker +ARG VERSION= ARG COMMIT=docker ARG DATE ARG GOPROXY @@ -61,9 +61,14 @@ COPY backend/ ./ COPY --from=frontend-builder /app/backend/internal/web/dist ./internal/web/dist # Build the binary (BuildType=release for CI builds, embed frontend) -RUN CGO_ENABLED=0 GOOS=linux go build \ +# Version precedence: build arg VERSION > cmd/server/VERSION +RUN VERSION_VALUE="${VERSION}" && \ + if [ -z "${VERSION_VALUE}" ]; then VERSION_VALUE="$(tr -d '\r\n' < ./cmd/server/VERSION)"; fi && \ + DATE_VALUE="${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)}" && \ + CGO_ENABLED=0 GOOS=linux go build \ -tags embed \ - -ldflags="-s -w -X main.Commit=${COMMIT} -X main.Date=${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)} -X main.BuildType=release" \ + -ldflags="-s -w -X main.Version=${VERSION_VALUE} -X main.Commit=${COMMIT} -X main.Date=${DATE_VALUE} -X main.BuildType=release" \ + -trimpath \ -o /app/sub2api \ ./cmd/server @@ -81,7 +86,6 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api" RUN apk add --no-cache \ ca-certificates \ tzdata \ - curl \ && rm -rf /var/cache/apk/* # Create non-root user @@ -91,11 +95,12 @@ RUN addgroup -g 1000 sub2api && \ # Set working directory WORKDIR /app -# Copy binary from builder -COPY --from=backend-builder /app/sub2api /app/sub2api +# Copy binary/resources with ownership to avoid extra 
full-layer chown copy +COPY --from=backend-builder --chown=sub2api:sub2api /app/sub2api /app/sub2api +COPY --from=backend-builder --chown=sub2api:sub2api /app/backend/resources /app/resources # Create data directory -RUN mkdir -p /app/data && chown -R sub2api:sub2api /app +RUN mkdir -p /app/data && chown sub2api:sub2api /app/data # Switch to non-root user USER sub2api @@ -105,7 +110,7 @@ EXPOSE 8080 # Health check HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \ - CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1 + CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1 # Run the application ENTRYPOINT ["/app/sub2api"] diff --git a/Makefile b/Makefile index a5e18a37..fd6a5a9a 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: build build-backend build-frontend test test-backend test-frontend +.PHONY: build build-backend build-frontend build-datamanagementd test test-backend test-frontend test-datamanagementd secret-scan # 一键编译前后端 build: build-backend build-frontend @@ -11,6 +11,10 @@ build-backend: build-frontend: @pnpm --dir frontend run build +# 编译 datamanagementd(宿主机数据管理进程) +build-datamanagementd: + @cd datamanagement && go build -o datamanagementd ./cmd/datamanagementd + # 运行测试(后端 + 前端) test: test-backend test-frontend @@ -20,3 +24,9 @@ test-backend: test-frontend: @pnpm --dir frontend run lint:check @pnpm --dir frontend run typecheck + +test-datamanagementd: + @cd datamanagement && go test ./... 
+ +secret-scan: + @python3 tools/secret_scan.py diff --git a/README.md b/README.md index 36949b0a..1e2f2290 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot ## Documentation - Dependency Security: `docs/dependency-security.md` +- Admin Payment Integration API: `docs/ADMIN_PAYMENT_INTEGRATION_API.md` --- @@ -363,6 +364,12 @@ default: rate_multiplier: 1.0 ``` +### Sora Status (Temporarily Unavailable) + +> ⚠️ Sora-related features are temporarily unavailable due to technical issues in upstream integration and media delivery. +> Please do not rely on Sora in production at this time. +> Existing `gateway.sora_*` configuration keys are reserved and may not take effect until these issues are resolved. + Additional security-related options are available in `config.yaml`: - `cors.allowed_origins` for CORS allowlist diff --git a/README_CN.md b/README_CN.md index 1e0d1d62..316cab94 100644 --- a/README_CN.md +++ b/README_CN.md @@ -62,8 +62,6 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅( - 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。 - 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。 ---- - ## 部署方式 ### 方式一:脚本安装(推荐) @@ -244,6 +242,18 @@ docker-compose -f docker-compose.local.yml logs -f sub2api **推荐:** 使用 `docker-compose.local.yml`(脚本部署)以便更轻松地管理数据。 +#### 启用“数据管理”功能(datamanagementd) + +如需启用管理后台“数据管理”,需要额外部署宿主机数据管理进程 `datamanagementd`。 + +关键点: + +- 主进程固定探测:`/tmp/sub2api-datamanagement.sock` +- 只有该 Socket 可连通时,数据管理功能才会开启 +- Docker 场景需将宿主机 Socket 挂载到容器同路径 + +详细部署步骤见:`deploy/DATAMANAGEMENTD_CN.md` + #### 访问 在浏览器中打开 `http://你的服务器IP:8080` @@ -370,6 +380,33 @@ default: rate_multiplier: 1.0 ``` +### Sora 功能状态(暂不可用) + +> ⚠️ 当前 Sora 相关功能因上游接入与媒体链路存在技术问题,暂时不可用。 +> 现阶段请勿在生产环境依赖 Sora 能力。 +> 文档中的 
`gateway.sora_*` 配置仅作预留,待技术问题修复后再恢复可用。 + +### Sora 媒体签名 URL(功能恢复后可选) + +当配置 `gateway.sora_media_signing_key` 且 `gateway.sora_media_signed_url_ttl_seconds > 0` 时,网关会将 Sora 输出的媒体地址改写为临时签名 URL(`/sora/media-signed/...`)。这样无需 API Key 即可在浏览器中直接访问,且具备过期控制与防篡改能力(签名包含 path + query)。 + +```yaml +gateway: + # /sora/media 是否强制要求 API Key(默认 false) + sora_media_require_api_key: false + # 媒体临时签名密钥(为空则禁用签名) + sora_media_signing_key: "your-signing-key" + # 临时签名 URL 有效期(秒) + sora_media_signed_url_ttl_seconds: 900 +``` + +> 若未配置签名密钥,`/sora/media-signed` 将返回 503。 +> 如需更严格的访问控制,可将 `sora_media_require_api_key` 设为 true,仅允许携带 API Key 的 `/sora/media` 访问。 + +访问策略说明: +- `/sora/media`:内部调用或客户端携带 API Key 才能下载 +- `/sora/media-signed`:外部可访问,但有签名 + 过期控制 + `config.yaml` 还支持以下安全相关配置: - `cors.allowed_origins` 配置 CORS 白名单 @@ -383,6 +420,14 @@ default: - `server.trusted_proxies` 启用可信代理解析 X-Forwarded-For - `turnstile.required` 在 release 模式强制启用 Turnstile +**网关防御纵深建议(重点)** + +- `gateway.upstream_response_read_max_bytes`:限制非流式上游响应读取大小(默认 `8MB`),用于防止异常响应导致内存放大。 +- `gateway.proxy_probe_response_read_max_bytes`:限制代理探测响应读取大小(默认 `1MB`)。 +- `gateway.gemini_debug_response_headers`:默认 `false`,仅在排障时短时开启,避免高频请求日志开销。 +- `/auth/register`、`/auth/login`、`/auth/login/2fa`、`/auth/send-verify-code` 已提供服务端兜底限流(Redis 故障时 fail-close)。 +- 推荐将 WAF/CDN 作为第一层防护,服务端限流与响应读取上限作为第二层兜底;两层同时保留,避免旁路流量与误配置风险。 + **⚠️ 安全警告:HTTP URL 配置** 当 `security.url_allowlist.enabled=false` 时,系统默认执行最小 URL 校验,**拒绝 HTTP URL**,仅允许 HTTPS。要允许 HTTP URL(例如用于开发或内网测试),必须显式设置: @@ -428,6 +473,29 @@ Invalid base URL: invalid url scheme: http ./sub2api ``` +#### HTTP/2 (h2c) 与 HTTP/1.1 回退 + +后端明文端口默认支持 h2c,并保留 HTTP/1.1 回退用于 WebSocket 与旧客户端。浏览器通常不支持 h2c,性能收益主要在反向代理或内网链路。 + +**反向代理示例(Caddy):** + +```caddyfile +transport http { + versions h2c h1 +} +``` + +**验证:** + +```bash +# h2c prior knowledge +curl --http2-prior-knowledge -I http://localhost:8080/health +# HTTP/1.1 回退 +curl --http1.1 -I http://localhost:8080/health +# WebSocket 回退验证(需管理员 token) +websocat 
-H="Sec-WebSocket-Protocol: sub2api-admin, jwt." ws://localhost:8080/api/v1/admin/ops/ws/qps +``` + #### 开发模式 ```bash diff --git a/backend/.golangci.yml b/backend/.golangci.yml index 3ec692a8..68b76751 100644 --- a/backend/.golangci.yml +++ b/backend/.golangci.yml @@ -5,6 +5,7 @@ linters: enable: - depguard - errcheck + - gosec - govet - ineffassign - staticcheck @@ -42,6 +43,22 @@ linters: desc: "handler must not import gorm" - pkg: github.com/redis/go-redis/v9 desc: "handler must not import redis" + gosec: + excludes: + - G101 + - G103 + - G104 + - G109 + - G115 + - G201 + - G202 + - G301 + - G302 + - G304 + - G306 + - G404 + severity: high + confidence: high errcheck: # Report about not checking of errors in type assertions: `a := b.(MyStruct)`. # Such cases aren't reported by default. diff --git a/backend/Makefile b/backend/Makefile index 6a5d2caa..7084ccb9 100644 --- a/backend/Makefile +++ b/backend/Makefile @@ -1,7 +1,14 @@ -.PHONY: build test test-unit test-integration test-e2e +.PHONY: build generate test test-unit test-integration test-e2e + +VERSION ?= $(shell tr -d '\r\n' < ./cmd/server/VERSION) +LDFLAGS ?= -s -w -X main.Version=$(VERSION) build: - go build -o bin/server ./cmd/server + CGO_ENABLED=0 go build -ldflags="$(LDFLAGS)" -trimpath -o bin/server ./cmd/server + +generate: + go generate ./ent + go generate ./cmd/server test: go test ./... @@ -14,4 +21,7 @@ test-integration: go test -tags=integration ./... test-e2e: - go test -tags=e2e ./... + ./scripts/e2e-test.sh + +test-e2e-local: + go test -tags=e2e -v -timeout=300s ./internal/integration/... 
diff --git a/backend/cmd/jwtgen/main.go b/backend/cmd/jwtgen/main.go index ce4718bf..bc001693 100644 --- a/backend/cmd/jwtgen/main.go +++ b/backend/cmd/jwtgen/main.go @@ -17,7 +17,7 @@ func main() { email := flag.String("email", "", "Admin email to issue a JWT for (defaults to first active admin)") flag.Parse() - cfg, err := config.Load() + cfg, err := config.LoadForBootstrap() if err != nil { log.Fatalf("failed to load config: %v", err) } @@ -33,7 +33,7 @@ func main() { }() userRepo := repository.NewUserRepository(client, sqlDB) - authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil) + authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 5087e794..32844913 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.76 \ No newline at end of file +0.1.88 \ No newline at end of file diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index f8a7d313..46edcb69 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -8,7 +8,6 @@ import ( "errors" "flag" "log" - "log/slog" "net/http" "os" "os/signal" @@ -19,11 +18,14 @@ import ( _ "github.com/Wei-Shaw/sub2api/ent/runtime" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/setup" "github.com/Wei-Shaw/sub2api/internal/web" "github.com/gin-gonic/gin" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" ) //go:embed VERSION @@ -38,7 +40,12 @@ var ( ) func init() { - // Read version from embedded VERSION file + // 如果 Version 已通过 ldflags 注入(例如 -X main.Version=...),则不要覆盖。 + if strings.TrimSpace(Version) != "" { + return + } + + // 默认从 
embedded VERSION 文件读取版本号(编译期打包进二进制)。 Version = strings.TrimSpace(embeddedVersion) if Version == "" { Version = "0.0.0-dev" @@ -47,22 +54,9 @@ func init() { // initLogger configures the default slog handler based on gin.Mode(). // In non-release mode, Debug level logs are enabled. -func initLogger() { - var level slog.Level - if gin.Mode() == gin.ReleaseMode { - level = slog.LevelInfo - } else { - level = slog.LevelDebug - } - handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ - Level: level, - }) - slog.SetDefault(slog.New(handler)) -} - func main() { - // Initialize slog logger based on gin mode - initLogger() + logger.InitBootstrap() + defer logger.Sync() // Parse command line flags setupMode := flag.Bool("setup", false, "Run setup wizard in CLI mode") @@ -106,7 +100,7 @@ func runSetupServer() { r := gin.New() r.Use(middleware.Recovery()) r.Use(middleware.CORS(config.CORSConfig{})) - r.Use(middleware.SecurityHeaders(config.CSPConfig{Enabled: true, Policy: config.DefaultCSPPolicy})) + r.Use(middleware.SecurityHeaders(config.CSPConfig{Enabled: true, Policy: config.DefaultCSPPolicy}, nil)) // Register setup routes setup.RegisterRoutes(r) @@ -122,16 +116,26 @@ func runSetupServer() { log.Printf("Setup wizard available at http://%s", addr) log.Println("Complete the setup wizard to configure Sub2API") - if err := r.Run(addr); err != nil { + server := &http.Server{ + Addr: addr, + Handler: h2c.NewHandler(r, &http2.Server{}), + ReadHeaderTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, + } + + if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { log.Fatalf("Failed to start setup server: %v", err) } } func runMainServer() { - cfg, err := config.Load() + cfg, err := config.LoadForBootstrap() if err != nil { log.Fatalf("Failed to load config: %v", err) } + if err := logger.Init(logger.OptionsFromConfig(cfg.Log)); err != nil { + log.Fatalf("Failed to initialize logger: %v", err) + } if cfg.RunMode == 
config.RunModeSimple { log.Println("⚠️ WARNING: Running in SIMPLE mode - billing and quota checks are DISABLED") } diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index d9ff788e..cbf89ba3 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -7,6 +7,7 @@ import ( "context" "log" "net/http" + "sync" "time" "github.com/Wei-Shaw/sub2api/ent" @@ -67,28 +68,36 @@ func provideCleanup( opsAlertEvaluator *service.OpsAlertEvaluatorService, opsCleanup *service.OpsCleanupService, opsScheduledReport *service.OpsScheduledReportService, + opsSystemLogSink *service.OpsSystemLogSink, + soraMediaCleanup *service.SoraMediaCleanupService, schedulerSnapshot *service.SchedulerSnapshotService, tokenRefresh *service.TokenRefreshService, accountExpiry *service.AccountExpiryService, subscriptionExpiry *service.SubscriptionExpiryService, usageCleanup *service.UsageCleanupService, + idempotencyCleanup *service.IdempotencyCleanupService, pricing *service.PricingService, emailQueue *service.EmailQueueService, billingCache *service.BillingCacheService, + usageRecordWorkerPool *service.UsageRecordWorkerPool, + subscriptionService *service.SubscriptionService, oauth *service.OAuthService, openaiOAuth *service.OpenAIOAuthService, geminiOAuth *service.GeminiOAuthService, antigravityOAuth *service.AntigravityOAuthService, + openAIGateway *service.OpenAIGatewayService, ) func() { return func() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - // Cleanup steps in reverse dependency order - cleanupSteps := []struct { + type cleanupStep struct { name string fn func() error - }{ + } + + // 应用层清理步骤可并行执行,基础设施资源(Redis/Ent)最后按顺序关闭。 + parallelSteps := []cleanupStep{ {"OpsScheduledReportService", func() error { if opsScheduledReport != nil { opsScheduledReport.Stop() @@ -101,6 +110,18 @@ func provideCleanup( } return nil }}, + {"OpsSystemLogSink", func() error { + if opsSystemLogSink != nil { + opsSystemLogSink.Stop() + } + 
return nil + }}, + {"SoraMediaCleanupService", func() error { + if soraMediaCleanup != nil { + soraMediaCleanup.Stop() + } + return nil + }}, {"OpsAlertEvaluatorService", func() error { if opsAlertEvaluator != nil { opsAlertEvaluator.Stop() @@ -131,6 +152,12 @@ func provideCleanup( } return nil }}, + {"IdempotencyCleanupService", func() error { + if idempotencyCleanup != nil { + idempotencyCleanup.Stop() + } + return nil + }}, {"TokenRefreshService", func() error { tokenRefresh.Stop() return nil @@ -143,6 +170,12 @@ func provideCleanup( subscriptionExpiry.Stop() return nil }}, + {"SubscriptionService", func() error { + if subscriptionService != nil { + subscriptionService.Stop() + } + return nil + }}, {"PricingService", func() error { pricing.Stop() return nil @@ -155,6 +188,12 @@ func provideCleanup( billingCache.Stop() return nil }}, + {"UsageRecordWorkerPool", func() error { + if usageRecordWorkerPool != nil { + usageRecordWorkerPool.Stop() + } + return nil + }}, {"OAuthService", func() error { oauth.Stop() return nil @@ -171,23 +210,60 @@ func provideCleanup( antigravityOAuth.Stop() return nil }}, + {"OpenAIWSPool", func() error { + if openAIGateway != nil { + openAIGateway.CloseOpenAIWSPool() + } + return nil + }}, + } + + infraSteps := []cleanupStep{ {"Redis", func() error { + if rdb == nil { + return nil + } return rdb.Close() }}, {"Ent", func() error { + if entClient == nil { + return nil + } return entClient.Close() }}, } - for _, step := range cleanupSteps { - if err := step.fn(); err != nil { - log.Printf("[Cleanup] %s failed: %v", step.name, err) - // Continue with remaining cleanup steps even if one fails - } else { + runParallel := func(steps []cleanupStep) { + var wg sync.WaitGroup + for i := range steps { + step := steps[i] + wg.Add(1) + go func() { + defer wg.Done() + if err := step.fn(); err != nil { + log.Printf("[Cleanup] %s failed: %v", step.name, err) + return + } + log.Printf("[Cleanup] %s succeeded", step.name) + }() + } + wg.Wait() + } + + 
runSequential := func(steps []cleanupStep) { + for i := range steps { + step := steps[i] + if err := step.fn(); err != nil { + log.Printf("[Cleanup] %s failed: %v", step.name, err) + continue + } log.Printf("[Cleanup] %s succeeded", step.name) } } + runParallel(parallelSteps) + runSequential(infraSteps) + // Check if context timed out select { case <-ctx.Done(): diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 5ccd797e..8e7aefe1 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -19,6 +19,7 @@ import ( "github.com/redis/go-redis/v9" "log" "net/http" + "sync" "time" ) @@ -47,7 +48,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { redisClient := repository.ProvideRedis(configConfig) refreshTokenCache := repository.NewRefreshTokenCache(redisClient) settingRepository := repository.NewSettingRepository(client) - settingService := service.NewSettingService(settingRepository, configConfig) + groupRepository := repository.NewGroupRepository(client, db) + settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig) emailCache := repository.NewEmailCache(redisClient) emailService := service.NewEmailService(settingRepository, emailCache) turnstileVerifier := repository.NewTurnstileVerifier() @@ -56,17 +58,17 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { promoCodeRepository := repository.NewPromoCodeRepository(client) billingCache := repository.NewBillingCache(redisClient) userSubscriptionRepository := repository.NewUserSubscriptionRepository(client) - billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) - apiKeyRepository := repository.NewAPIKeyRepository(client) - groupRepository := repository.NewGroupRepository(client, db) + apiKeyRepository := repository.NewAPIKeyRepository(client, db) + billingCacheService := 
service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, configConfig) userGroupRateRepository := repository.NewUserGroupRateRepository(db) apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig) + apiKeyService.SetRateLimitCacheInvalidator(billingCache) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) - authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) - userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator) - subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) + subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) + authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService) + userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache) redeemCache := repository.NewRedeemCache(redisClient) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) secretEncryptor, err := repository.NewAESEncryptor(configConfig) @@ -98,10 +100,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { dashboardHandler := 
admin.NewDashboardHandler(dashboardService, dashboardAggregationService) schedulerCache := repository.NewSchedulerCache(redisClient) accountRepository := repository.NewAccountRepository(client, db, schedulerCache) + soraAccountRepository := repository.NewSoraAccountRepository(db) proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) @@ -112,7 +115,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient) geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig) geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient() - geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig) + driveClient := repository.NewGeminiDriveClient() + geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, driveClient, configConfig) 
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository) geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository) tempUnschedCache := repository.NewTempUnschedCache(redisClient) @@ -135,14 +139,17 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) - accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator) + rpmCache := repository.NewRPMCache(redisClient) + accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator) adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService) + dataManagementService := service.NewDataManagementService() + dataManagementHandler := admin.NewDataManagementHandler(dataManagementService) oAuthHandler := admin.NewOAuthHandler(oAuthService) openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService) geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService) antigravityOAuthHandler := admin.NewAntigravityOAuthHandler(antigravityOAuthService) proxyHandler := admin.NewProxyHandler(adminService) - adminRedeemHandler := admin.NewRedeemHandler(adminService) + adminRedeemHandler := 
admin.NewRedeemHandler(adminService, redeemService) promoHandler := admin.NewPromoHandler(promoService) opsRepository := repository.NewOpsRepository(db) pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig) @@ -155,18 +162,26 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService) digestSessionStore := service.NewDigestSessionStore() - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, digestSessionStore) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore) openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, 
geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) - opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService) - settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService) + opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) + opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) + soraS3Storage := service.NewSoraS3Storage(settingService) + settingService.SetOnS3UpdateCallback(soraS3Storage.RefreshClient) + soraGenerationRepository := repository.NewSoraGenerationRepository(db) + soraQuotaService := service.NewSoraQuotaService(userRepository, groupRepository, settingService) + soraGenerationService := service.NewSoraGenerationService(soraGenerationRepository, soraS3Storage, soraQuotaService) + settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, soraS3Storage) opsHandler := admin.NewOpsHandler(opsService) updateCache := repository.NewUpdateCache(redisClient) gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig) serviceBuildInfo := provideServiceBuildInfo(buildInfo) updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo) - systemHandler := handler.ProvideSystemHandler(updateService) + idempotencyRepository := repository.NewIdempotencyRepository(client, db) + systemOperationLockService := service.ProvideSystemOperationLockService(idempotencyRepository, configConfig) + systemHandler := handler.ProvideSystemHandler(updateService, systemOperationLockService) adminSubscriptionHandler := 
admin.NewSubscriptionHandler(subscriptionService) usageCleanupRepository := repository.NewUsageCleanupRepository(client, db) usageCleanupService := service.ProvideUsageCleanupService(usageCleanupRepository, timingWheelService, dashboardAggregationService, configConfig) @@ -179,12 +194,23 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { errorPassthroughCache := repository.NewErrorPassthroughCache(redisClient) errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache) errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler) - gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig) - openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig) + adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler) + usageRecordWorkerPool := 
service.NewUsageRecordWorkerPool(configConfig) + userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) + userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService) + openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig) + soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository) + soraMediaStorage := service.ProvideSoraMediaStorage(configConfig) + soraGatewayService := service.NewSoraGatewayService(soraSDKClient, rateLimitService, httpUpstream, configConfig) + soraClientHandler := handler.NewSoraClientHandler(soraGenerationService, soraQuotaService, soraS3Storage, soraGatewayService, gatewayService, soraMediaStorage, apiKeyService) + soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) totpHandler := handler.NewTotpHandler(totpService) - handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler) + idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig) + idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig) + handlers := 
handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, soraClientHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) @@ -195,10 +221,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig) opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) - tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig) + soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig) + tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) - v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, 
schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) + v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService) application := &Application{ Server: httpServer, Cleanup: v, @@ -228,27 +255,35 @@ func provideCleanup( opsAlertEvaluator *service.OpsAlertEvaluatorService, opsCleanup *service.OpsCleanupService, opsScheduledReport *service.OpsScheduledReportService, + opsSystemLogSink *service.OpsSystemLogSink, + soraMediaCleanup *service.SoraMediaCleanupService, schedulerSnapshot *service.SchedulerSnapshotService, tokenRefresh *service.TokenRefreshService, accountExpiry *service.AccountExpiryService, subscriptionExpiry *service.SubscriptionExpiryService, usageCleanup *service.UsageCleanupService, + idempotencyCleanup *service.IdempotencyCleanupService, pricing *service.PricingService, emailQueue *service.EmailQueueService, billingCache *service.BillingCacheService, + usageRecordWorkerPool *service.UsageRecordWorkerPool, + subscriptionService *service.SubscriptionService, oauth *service.OAuthService, openaiOAuth *service.OpenAIOAuthService, geminiOAuth *service.GeminiOAuthService, antigravityOAuth *service.AntigravityOAuthService, + openAIGateway *service.OpenAIGatewayService, ) func() { return func() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - 
cleanupSteps := []struct { + type cleanupStep struct { name string fn func() error - }{ + } + + parallelSteps := []cleanupStep{ {"OpsScheduledReportService", func() error { if opsScheduledReport != nil { opsScheduledReport.Stop() @@ -261,6 +296,18 @@ func provideCleanup( } return nil }}, + {"OpsSystemLogSink", func() error { + if opsSystemLogSink != nil { + opsSystemLogSink.Stop() + } + return nil + }}, + {"SoraMediaCleanupService", func() error { + if soraMediaCleanup != nil { + soraMediaCleanup.Stop() + } + return nil + }}, {"OpsAlertEvaluatorService", func() error { if opsAlertEvaluator != nil { opsAlertEvaluator.Stop() @@ -291,6 +338,12 @@ func provideCleanup( } return nil }}, + {"IdempotencyCleanupService", func() error { + if idempotencyCleanup != nil { + idempotencyCleanup.Stop() + } + return nil + }}, {"TokenRefreshService", func() error { tokenRefresh.Stop() return nil @@ -303,6 +356,12 @@ func provideCleanup( subscriptionExpiry.Stop() return nil }}, + {"SubscriptionService", func() error { + if subscriptionService != nil { + subscriptionService.Stop() + } + return nil + }}, {"PricingService", func() error { pricing.Stop() return nil @@ -315,6 +374,12 @@ func provideCleanup( billingCache.Stop() return nil }}, + {"UsageRecordWorkerPool", func() error { + if usageRecordWorkerPool != nil { + usageRecordWorkerPool.Stop() + } + return nil + }}, {"OAuthService", func() error { oauth.Stop() return nil @@ -331,23 +396,60 @@ func provideCleanup( antigravityOAuth.Stop() return nil }}, + {"OpenAIWSPool", func() error { + if openAIGateway != nil { + openAIGateway.CloseOpenAIWSPool() + } + return nil + }}, + } + + infraSteps := []cleanupStep{ {"Redis", func() error { + if rdb == nil { + return nil + } return rdb.Close() }}, {"Ent", func() error { + if entClient == nil { + return nil + } return entClient.Close() }}, } - for _, step := range cleanupSteps { - if err := step.fn(); err != nil { - log.Printf("[Cleanup] %s failed: %v", step.name, err) + runParallel := 
func(steps []cleanupStep) { + var wg sync.WaitGroup + for i := range steps { + step := steps[i] + wg.Add(1) + go func() { + defer wg.Done() + if err := step.fn(); err != nil { + log.Printf("[Cleanup] %s failed: %v", step.name, err) + return + } + log.Printf("[Cleanup] %s succeeded", step.name) + }() + } + wg.Wait() + } - } else { + runSequential := func(steps []cleanupStep) { + for i := range steps { + step := steps[i] + if err := step.fn(); err != nil { + log.Printf("[Cleanup] %s failed: %v", step.name, err) + continue + } log.Printf("[Cleanup] %s succeeded", step.name) } } + runParallel(parallelSteps) + runSequential(infraSteps) + select { case <-ctx.Done(): log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds") diff --git a/backend/cmd/server/wire_gen_test.go b/backend/cmd/server/wire_gen_test.go new file mode 100644 index 00000000..373bfd88 --- /dev/null +++ b/backend/cmd/server/wire_gen_test.go @@ -0,0 +1,82 @@ +package main + +import ( + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/handler" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestProvideServiceBuildInfo(t *testing.T) { + in := handler.BuildInfo{ + Version: "v-test", + BuildType: "release", + } + out := provideServiceBuildInfo(in) + require.Equal(t, in.Version, out.Version) + require.Equal(t, in.BuildType, out.BuildType) +} + +func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) { + cfg := &config.Config{} + + oauthSvc := service.NewOAuthService(nil, nil) + openAIOAuthSvc := service.NewOpenAIOAuthService(nil, nil) + geminiOAuthSvc := service.NewGeminiOAuthService(nil, nil, nil, nil, cfg) + antigravityOAuthSvc := service.NewAntigravityOAuthService(nil) + + tokenRefreshSvc := service.NewTokenRefreshService( + nil, + oauthSvc, + openAIOAuthSvc, + geminiOAuthSvc, + antigravityOAuthSvc, + nil, + nil, + cfg, + nil, + ) + accountExpirySvc := 
service.NewAccountExpiryService(nil, time.Second) + subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second) + pricingSvc := service.NewPricingService(cfg, nil) + emailQueueSvc := service.NewEmailQueueService(nil, 1) + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg) + idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg) + schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg) + opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil) + + cleanup := provideCleanup( + nil, // entClient + nil, // redis + &service.OpsMetricsCollector{}, + &service.OpsAggregationService{}, + &service.OpsAlertEvaluatorService{}, + &service.OpsCleanupService{}, + &service.OpsScheduledReportService{}, + opsSystemLogSinkSvc, + &service.SoraMediaCleanupService{}, + schedulerSnapshotSvc, + tokenRefreshSvc, + accountExpirySvc, + subscriptionExpirySvc, + &service.UsageCleanupService{}, + idempotencyCleanupSvc, + pricingSvc, + emailQueueSvc, + billingCacheSvc, + &service.UsageRecordWorkerPool{}, + &service.SubscriptionService{}, + oauthSvc, + openAIOAuthSvc, + geminiOAuthSvc, + antigravityOAuthSvc, + nil, // openAIGateway + ) + + require.NotPanics(t, func() { + cleanup() + }) +} diff --git a/backend/ent/account.go b/backend/ent/account.go index 038aa7e5..c77002b3 100644 --- a/backend/ent/account.go +++ b/backend/ent/account.go @@ -63,6 +63,10 @@ type Account struct { RateLimitResetAt *time.Time `json:"rate_limit_reset_at,omitempty"` // OverloadUntil holds the value of the "overload_until" field. OverloadUntil *time.Time `json:"overload_until,omitempty"` + // TempUnschedulableUntil holds the value of the "temp_unschedulable_until" field. + TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"` + // TempUnschedulableReason holds the value of the "temp_unschedulable_reason" field. 
+ TempUnschedulableReason *string `json:"temp_unschedulable_reason,omitempty"` // SessionWindowStart holds the value of the "session_window_start" field. SessionWindowStart *time.Time `json:"session_window_start,omitempty"` // SessionWindowEnd holds the value of the "session_window_end" field. @@ -141,9 +145,9 @@ func (*Account) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullFloat64) case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority: values[i] = new(sql.NullInt64) - case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldSessionWindowStatus: + case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldTempUnschedulableReason, account.FieldSessionWindowStatus: values[i] = new(sql.NullString) - case account.FieldCreatedAt, account.FieldUpdatedAt, account.FieldDeletedAt, account.FieldLastUsedAt, account.FieldExpiresAt, account.FieldRateLimitedAt, account.FieldRateLimitResetAt, account.FieldOverloadUntil, account.FieldSessionWindowStart, account.FieldSessionWindowEnd: + case account.FieldCreatedAt, account.FieldUpdatedAt, account.FieldDeletedAt, account.FieldLastUsedAt, account.FieldExpiresAt, account.FieldRateLimitedAt, account.FieldRateLimitResetAt, account.FieldOverloadUntil, account.FieldTempUnschedulableUntil, account.FieldSessionWindowStart, account.FieldSessionWindowEnd: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -311,6 +315,20 @@ func (_m *Account) assignValues(columns []string, values []any) error { _m.OverloadUntil = new(time.Time) *_m.OverloadUntil = value.Time } + case account.FieldTempUnschedulableUntil: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field temp_unschedulable_until", values[i]) + } else if value.Valid { + 
_m.TempUnschedulableUntil = new(time.Time) + *_m.TempUnschedulableUntil = value.Time + } + case account.FieldTempUnschedulableReason: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field temp_unschedulable_reason", values[i]) + } else if value.Valid { + _m.TempUnschedulableReason = new(string) + *_m.TempUnschedulableReason = value.String + } case account.FieldSessionWindowStart: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field session_window_start", values[i]) @@ -472,6 +490,16 @@ func (_m *Account) String() string { builder.WriteString(v.Format(time.ANSIC)) } builder.WriteString(", ") + if v := _m.TempUnschedulableUntil; v != nil { + builder.WriteString("temp_unschedulable_until=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.TempUnschedulableReason; v != nil { + builder.WriteString("temp_unschedulable_reason=") + builder.WriteString(*v) + } + builder.WriteString(", ") if v := _m.SessionWindowStart; v != nil { builder.WriteString("session_window_start=") builder.WriteString(v.Format(time.ANSIC)) diff --git a/backend/ent/account/account.go b/backend/ent/account/account.go index 73c0e8c2..1fc34620 100644 --- a/backend/ent/account/account.go +++ b/backend/ent/account/account.go @@ -59,6 +59,10 @@ const ( FieldRateLimitResetAt = "rate_limit_reset_at" // FieldOverloadUntil holds the string denoting the overload_until field in the database. FieldOverloadUntil = "overload_until" + // FieldTempUnschedulableUntil holds the string denoting the temp_unschedulable_until field in the database. + FieldTempUnschedulableUntil = "temp_unschedulable_until" + // FieldTempUnschedulableReason holds the string denoting the temp_unschedulable_reason field in the database. + FieldTempUnschedulableReason = "temp_unschedulable_reason" // FieldSessionWindowStart holds the string denoting the session_window_start field in the database. 
FieldSessionWindowStart = "session_window_start" // FieldSessionWindowEnd holds the string denoting the session_window_end field in the database. @@ -128,6 +132,8 @@ var Columns = []string{ FieldRateLimitedAt, FieldRateLimitResetAt, FieldOverloadUntil, + FieldTempUnschedulableUntil, + FieldTempUnschedulableReason, FieldSessionWindowStart, FieldSessionWindowEnd, FieldSessionWindowStatus, @@ -299,6 +305,16 @@ func ByOverloadUntil(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldOverloadUntil, opts...).ToFunc() } +// ByTempUnschedulableUntil orders the results by the temp_unschedulable_until field. +func ByTempUnschedulableUntil(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTempUnschedulableUntil, opts...).ToFunc() +} + +// ByTempUnschedulableReason orders the results by the temp_unschedulable_reason field. +func ByTempUnschedulableReason(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTempUnschedulableReason, opts...).ToFunc() +} + // BySessionWindowStart orders the results by the session_window_start field. func BySessionWindowStart(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldSessionWindowStart, opts...).ToFunc() diff --git a/backend/ent/account/where.go b/backend/ent/account/where.go index dea1127a..54db1dcb 100644 --- a/backend/ent/account/where.go +++ b/backend/ent/account/where.go @@ -155,6 +155,16 @@ func OverloadUntil(v time.Time) predicate.Account { return predicate.Account(sql.FieldEQ(FieldOverloadUntil, v)) } +// TempUnschedulableUntil applies equality check predicate on the "temp_unschedulable_until" field. It's identical to TempUnschedulableUntilEQ. +func TempUnschedulableUntil(v time.Time) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableReason applies equality check predicate on the "temp_unschedulable_reason" field. It's identical to TempUnschedulableReasonEQ. 
+func TempUnschedulableReason(v string) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldTempUnschedulableReason, v)) +} + // SessionWindowStart applies equality check predicate on the "session_window_start" field. It's identical to SessionWindowStartEQ. func SessionWindowStart(v time.Time) predicate.Account { return predicate.Account(sql.FieldEQ(FieldSessionWindowStart, v)) @@ -1130,6 +1140,131 @@ func OverloadUntilNotNil() predicate.Account { return predicate.Account(sql.FieldNotNull(FieldOverloadUntil)) } +// TempUnschedulableUntilEQ applies the EQ predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilEQ(v time.Time) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableUntilNEQ applies the NEQ predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilNEQ(v time.Time) predicate.Account { + return predicate.Account(sql.FieldNEQ(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableUntilIn applies the In predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilIn(vs ...time.Time) predicate.Account { + return predicate.Account(sql.FieldIn(FieldTempUnschedulableUntil, vs...)) +} + +// TempUnschedulableUntilNotIn applies the NotIn predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilNotIn(vs ...time.Time) predicate.Account { + return predicate.Account(sql.FieldNotIn(FieldTempUnschedulableUntil, vs...)) +} + +// TempUnschedulableUntilGT applies the GT predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilGT(v time.Time) predicate.Account { + return predicate.Account(sql.FieldGT(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableUntilGTE applies the GTE predicate on the "temp_unschedulable_until" field. 
+func TempUnschedulableUntilGTE(v time.Time) predicate.Account { + return predicate.Account(sql.FieldGTE(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableUntilLT applies the LT predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilLT(v time.Time) predicate.Account { + return predicate.Account(sql.FieldLT(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableUntilLTE applies the LTE predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilLTE(v time.Time) predicate.Account { + return predicate.Account(sql.FieldLTE(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableUntilIsNil applies the IsNil predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilIsNil() predicate.Account { + return predicate.Account(sql.FieldIsNull(FieldTempUnschedulableUntil)) +} + +// TempUnschedulableUntilNotNil applies the NotNil predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilNotNil() predicate.Account { + return predicate.Account(sql.FieldNotNull(FieldTempUnschedulableUntil)) +} + +// TempUnschedulableReasonEQ applies the EQ predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonEQ(v string) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonNEQ applies the NEQ predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonNEQ(v string) predicate.Account { + return predicate.Account(sql.FieldNEQ(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonIn applies the In predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonIn(vs ...string) predicate.Account { + return predicate.Account(sql.FieldIn(FieldTempUnschedulableReason, vs...)) +} + +// TempUnschedulableReasonNotIn applies the NotIn predicate on the "temp_unschedulable_reason" field. 
+func TempUnschedulableReasonNotIn(vs ...string) predicate.Account { + return predicate.Account(sql.FieldNotIn(FieldTempUnschedulableReason, vs...)) +} + +// TempUnschedulableReasonGT applies the GT predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonGT(v string) predicate.Account { + return predicate.Account(sql.FieldGT(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonGTE applies the GTE predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonGTE(v string) predicate.Account { + return predicate.Account(sql.FieldGTE(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonLT applies the LT predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonLT(v string) predicate.Account { + return predicate.Account(sql.FieldLT(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonLTE applies the LTE predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonLTE(v string) predicate.Account { + return predicate.Account(sql.FieldLTE(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonContains applies the Contains predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonContains(v string) predicate.Account { + return predicate.Account(sql.FieldContains(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonHasPrefix applies the HasPrefix predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonHasPrefix(v string) predicate.Account { + return predicate.Account(sql.FieldHasPrefix(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonHasSuffix applies the HasSuffix predicate on the "temp_unschedulable_reason" field. 
+func TempUnschedulableReasonHasSuffix(v string) predicate.Account { + return predicate.Account(sql.FieldHasSuffix(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonIsNil applies the IsNil predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonIsNil() predicate.Account { + return predicate.Account(sql.FieldIsNull(FieldTempUnschedulableReason)) +} + +// TempUnschedulableReasonNotNil applies the NotNil predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonNotNil() predicate.Account { + return predicate.Account(sql.FieldNotNull(FieldTempUnschedulableReason)) +} + +// TempUnschedulableReasonEqualFold applies the EqualFold predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonEqualFold(v string) predicate.Account { + return predicate.Account(sql.FieldEqualFold(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonContainsFold applies the ContainsFold predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonContainsFold(v string) predicate.Account { + return predicate.Account(sql.FieldContainsFold(FieldTempUnschedulableReason, v)) +} + // SessionWindowStartEQ applies the EQ predicate on the "session_window_start" field. func SessionWindowStartEQ(v time.Time) predicate.Account { return predicate.Account(sql.FieldEQ(FieldSessionWindowStart, v)) diff --git a/backend/ent/account_create.go b/backend/ent/account_create.go index 42a561cf..963ffee8 100644 --- a/backend/ent/account_create.go +++ b/backend/ent/account_create.go @@ -293,6 +293,34 @@ func (_c *AccountCreate) SetNillableOverloadUntil(v *time.Time) *AccountCreate { return _c } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. 
+func (_c *AccountCreate) SetTempUnschedulableUntil(v time.Time) *AccountCreate { + _c.mutation.SetTempUnschedulableUntil(v) + return _c +} + +// SetNillableTempUnschedulableUntil sets the "temp_unschedulable_until" field if the given value is not nil. +func (_c *AccountCreate) SetNillableTempUnschedulableUntil(v *time.Time) *AccountCreate { + if v != nil { + _c.SetTempUnschedulableUntil(*v) + } + return _c +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (_c *AccountCreate) SetTempUnschedulableReason(v string) *AccountCreate { + _c.mutation.SetTempUnschedulableReason(v) + return _c +} + +// SetNillableTempUnschedulableReason sets the "temp_unschedulable_reason" field if the given value is not nil. +func (_c *AccountCreate) SetNillableTempUnschedulableReason(v *string) *AccountCreate { + if v != nil { + _c.SetTempUnschedulableReason(*v) + } + return _c +} + // SetSessionWindowStart sets the "session_window_start" field. func (_c *AccountCreate) SetSessionWindowStart(v time.Time) *AccountCreate { _c.mutation.SetSessionWindowStart(v) @@ -639,6 +667,14 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) { _spec.SetField(account.FieldOverloadUntil, field.TypeTime, value) _node.OverloadUntil = &value } + if value, ok := _c.mutation.TempUnschedulableUntil(); ok { + _spec.SetField(account.FieldTempUnschedulableUntil, field.TypeTime, value) + _node.TempUnschedulableUntil = &value + } + if value, ok := _c.mutation.TempUnschedulableReason(); ok { + _spec.SetField(account.FieldTempUnschedulableReason, field.TypeString, value) + _node.TempUnschedulableReason = &value + } if value, ok := _c.mutation.SessionWindowStart(); ok { _spec.SetField(account.FieldSessionWindowStart, field.TypeTime, value) _node.SessionWindowStart = &value @@ -1080,6 +1116,42 @@ func (u *AccountUpsert) ClearOverloadUntil() *AccountUpsert { return u } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. 
+func (u *AccountUpsert) SetTempUnschedulableUntil(v time.Time) *AccountUpsert { + u.Set(account.FieldTempUnschedulableUntil, v) + return u +} + +// UpdateTempUnschedulableUntil sets the "temp_unschedulable_until" field to the value that was provided on create. +func (u *AccountUpsert) UpdateTempUnschedulableUntil() *AccountUpsert { + u.SetExcluded(account.FieldTempUnschedulableUntil) + return u +} + +// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field. +func (u *AccountUpsert) ClearTempUnschedulableUntil() *AccountUpsert { + u.SetNull(account.FieldTempUnschedulableUntil) + return u +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (u *AccountUpsert) SetTempUnschedulableReason(v string) *AccountUpsert { + u.Set(account.FieldTempUnschedulableReason, v) + return u +} + +// UpdateTempUnschedulableReason sets the "temp_unschedulable_reason" field to the value that was provided on create. +func (u *AccountUpsert) UpdateTempUnschedulableReason() *AccountUpsert { + u.SetExcluded(account.FieldTempUnschedulableReason) + return u +} + +// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field. +func (u *AccountUpsert) ClearTempUnschedulableReason() *AccountUpsert { + u.SetNull(account.FieldTempUnschedulableReason) + return u +} + // SetSessionWindowStart sets the "session_window_start" field. func (u *AccountUpsert) SetSessionWindowStart(v time.Time) *AccountUpsert { u.Set(account.FieldSessionWindowStart, v) @@ -1557,6 +1629,48 @@ func (u *AccountUpsertOne) ClearOverloadUntil() *AccountUpsertOne { }) } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. +func (u *AccountUpsertOne) SetTempUnschedulableUntil(v time.Time) *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.SetTempUnschedulableUntil(v) + }) +} + +// UpdateTempUnschedulableUntil sets the "temp_unschedulable_until" field to the value that was provided on create. 
+func (u *AccountUpsertOne) UpdateTempUnschedulableUntil() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.UpdateTempUnschedulableUntil() + }) +} + +// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field. +func (u *AccountUpsertOne) ClearTempUnschedulableUntil() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.ClearTempUnschedulableUntil() + }) +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (u *AccountUpsertOne) SetTempUnschedulableReason(v string) *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.SetTempUnschedulableReason(v) + }) +} + +// UpdateTempUnschedulableReason sets the "temp_unschedulable_reason" field to the value that was provided on create. +func (u *AccountUpsertOne) UpdateTempUnschedulableReason() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.UpdateTempUnschedulableReason() + }) +} + +// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field. +func (u *AccountUpsertOne) ClearTempUnschedulableReason() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.ClearTempUnschedulableReason() + }) +} + // SetSessionWindowStart sets the "session_window_start" field. func (u *AccountUpsertOne) SetSessionWindowStart(v time.Time) *AccountUpsertOne { return u.Update(func(s *AccountUpsert) { @@ -2209,6 +2323,48 @@ func (u *AccountUpsertBulk) ClearOverloadUntil() *AccountUpsertBulk { }) } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. +func (u *AccountUpsertBulk) SetTempUnschedulableUntil(v time.Time) *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.SetTempUnschedulableUntil(v) + }) +} + +// UpdateTempUnschedulableUntil sets the "temp_unschedulable_until" field to the value that was provided on create. 
+func (u *AccountUpsertBulk) UpdateTempUnschedulableUntil() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.UpdateTempUnschedulableUntil() + }) +} + +// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field. +func (u *AccountUpsertBulk) ClearTempUnschedulableUntil() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.ClearTempUnschedulableUntil() + }) +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (u *AccountUpsertBulk) SetTempUnschedulableReason(v string) *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.SetTempUnschedulableReason(v) + }) +} + +// UpdateTempUnschedulableReason sets the "temp_unschedulable_reason" field to the value that was provided on create. +func (u *AccountUpsertBulk) UpdateTempUnschedulableReason() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.UpdateTempUnschedulableReason() + }) +} + +// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field. +func (u *AccountUpsertBulk) ClearTempUnschedulableReason() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.ClearTempUnschedulableReason() + }) +} + // SetSessionWindowStart sets the "session_window_start" field. func (u *AccountUpsertBulk) SetSessionWindowStart(v time.Time) *AccountUpsertBulk { return u.Update(func(s *AccountUpsert) { diff --git a/backend/ent/account_update.go b/backend/ent/account_update.go index 63fab096..875888e0 100644 --- a/backend/ent/account_update.go +++ b/backend/ent/account_update.go @@ -376,6 +376,46 @@ func (_u *AccountUpdate) ClearOverloadUntil() *AccountUpdate { return _u } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. 
+func (_u *AccountUpdate) SetTempUnschedulableUntil(v time.Time) *AccountUpdate { + _u.mutation.SetTempUnschedulableUntil(v) + return _u +} + +// SetNillableTempUnschedulableUntil sets the "temp_unschedulable_until" field if the given value is not nil. +func (_u *AccountUpdate) SetNillableTempUnschedulableUntil(v *time.Time) *AccountUpdate { + if v != nil { + _u.SetTempUnschedulableUntil(*v) + } + return _u +} + +// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field. +func (_u *AccountUpdate) ClearTempUnschedulableUntil() *AccountUpdate { + _u.mutation.ClearTempUnschedulableUntil() + return _u +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (_u *AccountUpdate) SetTempUnschedulableReason(v string) *AccountUpdate { + _u.mutation.SetTempUnschedulableReason(v) + return _u +} + +// SetNillableTempUnschedulableReason sets the "temp_unschedulable_reason" field if the given value is not nil. +func (_u *AccountUpdate) SetNillableTempUnschedulableReason(v *string) *AccountUpdate { + if v != nil { + _u.SetTempUnschedulableReason(*v) + } + return _u +} + +// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field. +func (_u *AccountUpdate) ClearTempUnschedulableReason() *AccountUpdate { + _u.mutation.ClearTempUnschedulableReason() + return _u +} + // SetSessionWindowStart sets the "session_window_start" field. 
func (_u *AccountUpdate) SetSessionWindowStart(v time.Time) *AccountUpdate { _u.mutation.SetSessionWindowStart(v) @@ -701,6 +741,18 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.OverloadUntilCleared() { _spec.ClearField(account.FieldOverloadUntil, field.TypeTime) } + if value, ok := _u.mutation.TempUnschedulableUntil(); ok { + _spec.SetField(account.FieldTempUnschedulableUntil, field.TypeTime, value) + } + if _u.mutation.TempUnschedulableUntilCleared() { + _spec.ClearField(account.FieldTempUnschedulableUntil, field.TypeTime) + } + if value, ok := _u.mutation.TempUnschedulableReason(); ok { + _spec.SetField(account.FieldTempUnschedulableReason, field.TypeString, value) + } + if _u.mutation.TempUnschedulableReasonCleared() { + _spec.ClearField(account.FieldTempUnschedulableReason, field.TypeString) + } if value, ok := _u.mutation.SessionWindowStart(); ok { _spec.SetField(account.FieldSessionWindowStart, field.TypeTime, value) } @@ -1215,6 +1267,46 @@ func (_u *AccountUpdateOne) ClearOverloadUntil() *AccountUpdateOne { return _u } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. +func (_u *AccountUpdateOne) SetTempUnschedulableUntil(v time.Time) *AccountUpdateOne { + _u.mutation.SetTempUnschedulableUntil(v) + return _u +} + +// SetNillableTempUnschedulableUntil sets the "temp_unschedulable_until" field if the given value is not nil. +func (_u *AccountUpdateOne) SetNillableTempUnschedulableUntil(v *time.Time) *AccountUpdateOne { + if v != nil { + _u.SetTempUnschedulableUntil(*v) + } + return _u +} + +// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field. +func (_u *AccountUpdateOne) ClearTempUnschedulableUntil() *AccountUpdateOne { + _u.mutation.ClearTempUnschedulableUntil() + return _u +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. 
+func (_u *AccountUpdateOne) SetTempUnschedulableReason(v string) *AccountUpdateOne { + _u.mutation.SetTempUnschedulableReason(v) + return _u +} + +// SetNillableTempUnschedulableReason sets the "temp_unschedulable_reason" field if the given value is not nil. +func (_u *AccountUpdateOne) SetNillableTempUnschedulableReason(v *string) *AccountUpdateOne { + if v != nil { + _u.SetTempUnschedulableReason(*v) + } + return _u +} + +// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field. +func (_u *AccountUpdateOne) ClearTempUnschedulableReason() *AccountUpdateOne { + _u.mutation.ClearTempUnschedulableReason() + return _u +} + // SetSessionWindowStart sets the "session_window_start" field. func (_u *AccountUpdateOne) SetSessionWindowStart(v time.Time) *AccountUpdateOne { _u.mutation.SetSessionWindowStart(v) @@ -1570,6 +1662,18 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er if _u.mutation.OverloadUntilCleared() { _spec.ClearField(account.FieldOverloadUntil, field.TypeTime) } + if value, ok := _u.mutation.TempUnschedulableUntil(); ok { + _spec.SetField(account.FieldTempUnschedulableUntil, field.TypeTime, value) + } + if _u.mutation.TempUnschedulableUntilCleared() { + _spec.ClearField(account.FieldTempUnschedulableUntil, field.TypeTime) + } + if value, ok := _u.mutation.TempUnschedulableReason(); ok { + _spec.SetField(account.FieldTempUnschedulableReason, field.TypeString, value) + } + if _u.mutation.TempUnschedulableReasonCleared() { + _spec.ClearField(account.FieldTempUnschedulableReason, field.TypeString) + } if value, ok := _u.mutation.SessionWindowStart(); ok { _spec.SetField(account.FieldSessionWindowStart, field.TypeTime, value) } diff --git a/backend/ent/apikey.go b/backend/ent/apikey.go index 91d71964..9ee660c2 100644 --- a/backend/ent/apikey.go +++ b/backend/ent/apikey.go @@ -36,6 +36,8 @@ type APIKey struct { GroupID *int64 `json:"group_id,omitempty"` // Status holds the value of the "status" 
field. Status string `json:"status,omitempty"` + // Last usage time of this API key + LastUsedAt *time.Time `json:"last_used_at,omitempty"` // Allowed IPs/CIDRs, e.g. ["192.168.1.100", "10.0.0.0/8"] IPWhitelist []string `json:"ip_whitelist,omitempty"` // Blocked IPs/CIDRs @@ -46,6 +48,24 @@ type APIKey struct { QuotaUsed float64 `json:"quota_used,omitempty"` // Expiration time for this API key (null = never expires) ExpiresAt *time.Time `json:"expires_at,omitempty"` + // Rate limit in USD per 5 hours (0 = unlimited) + RateLimit5h float64 `json:"rate_limit_5h,omitempty"` + // Rate limit in USD per day (0 = unlimited) + RateLimit1d float64 `json:"rate_limit_1d,omitempty"` + // Rate limit in USD per 7 days (0 = unlimited) + RateLimit7d float64 `json:"rate_limit_7d,omitempty"` + // Used amount in USD for the current 5h window + Usage5h float64 `json:"usage_5h,omitempty"` + // Used amount in USD for the current 1d window + Usage1d float64 `json:"usage_1d,omitempty"` + // Used amount in USD for the current 7d window + Usage7d float64 `json:"usage_7d,omitempty"` + // Start time of the current 5h rate limit window + Window5hStart *time.Time `json:"window_5h_start,omitempty"` + // Start time of the current 1d rate limit window + Window1dStart *time.Time `json:"window_1d_start,omitempty"` + // Start time of the current 7d rate limit window + Window7dStart *time.Time `json:"window_7d_start,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the APIKeyQuery when eager-loading is set. 
Edges APIKeyEdges `json:"edges"` @@ -103,13 +123,13 @@ func (*APIKey) scanValues(columns []string) ([]any, error) { switch columns[i] { case apikey.FieldIPWhitelist, apikey.FieldIPBlacklist: values[i] = new([]byte) - case apikey.FieldQuota, apikey.FieldQuotaUsed: + case apikey.FieldQuota, apikey.FieldQuotaUsed, apikey.FieldRateLimit5h, apikey.FieldRateLimit1d, apikey.FieldRateLimit7d, apikey.FieldUsage5h, apikey.FieldUsage1d, apikey.FieldUsage7d: values[i] = new(sql.NullFloat64) case apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID: values[i] = new(sql.NullInt64) case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus: values[i] = new(sql.NullString) - case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt, apikey.FieldExpiresAt: + case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt, apikey.FieldLastUsedAt, apikey.FieldExpiresAt, apikey.FieldWindow5hStart, apikey.FieldWindow1dStart, apikey.FieldWindow7dStart: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -182,6 +202,13 @@ func (_m *APIKey) assignValues(columns []string, values []any) error { } else if value.Valid { _m.Status = value.String } + case apikey.FieldLastUsedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field last_used_at", values[i]) + } else if value.Valid { + _m.LastUsedAt = new(time.Time) + *_m.LastUsedAt = value.Time + } case apikey.FieldIPWhitelist: if value, ok := values[i].(*[]byte); !ok { return fmt.Errorf("unexpected type %T for field ip_whitelist", values[i]) @@ -217,6 +244,63 @@ func (_m *APIKey) assignValues(columns []string, values []any) error { _m.ExpiresAt = new(time.Time) *_m.ExpiresAt = value.Time } + case apikey.FieldRateLimit5h: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field rate_limit_5h", values[i]) + } else if value.Valid { + _m.RateLimit5h = value.Float64 + } + case apikey.FieldRateLimit1d: + if 
value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field rate_limit_1d", values[i]) + } else if value.Valid { + _m.RateLimit1d = value.Float64 + } + case apikey.FieldRateLimit7d: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field rate_limit_7d", values[i]) + } else if value.Valid { + _m.RateLimit7d = value.Float64 + } + case apikey.FieldUsage5h: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field usage_5h", values[i]) + } else if value.Valid { + _m.Usage5h = value.Float64 + } + case apikey.FieldUsage1d: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field usage_1d", values[i]) + } else if value.Valid { + _m.Usage1d = value.Float64 + } + case apikey.FieldUsage7d: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field usage_7d", values[i]) + } else if value.Valid { + _m.Usage7d = value.Float64 + } + case apikey.FieldWindow5hStart: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field window_5h_start", values[i]) + } else if value.Valid { + _m.Window5hStart = new(time.Time) + *_m.Window5hStart = value.Time + } + case apikey.FieldWindow1dStart: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field window_1d_start", values[i]) + } else if value.Valid { + _m.Window1dStart = new(time.Time) + *_m.Window1dStart = value.Time + } + case apikey.FieldWindow7dStart: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field window_7d_start", values[i]) + } else if value.Valid { + _m.Window7dStart = new(time.Time) + *_m.Window7dStart = value.Time + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -296,6 +380,11 @@ func (_m *APIKey) String() string { builder.WriteString("status=") 
builder.WriteString(_m.Status) builder.WriteString(", ") + if v := _m.LastUsedAt; v != nil { + builder.WriteString("last_used_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") builder.WriteString("ip_whitelist=") builder.WriteString(fmt.Sprintf("%v", _m.IPWhitelist)) builder.WriteString(", ") @@ -312,6 +401,39 @@ func (_m *APIKey) String() string { builder.WriteString("expires_at=") builder.WriteString(v.Format(time.ANSIC)) } + builder.WriteString(", ") + builder.WriteString("rate_limit_5h=") + builder.WriteString(fmt.Sprintf("%v", _m.RateLimit5h)) + builder.WriteString(", ") + builder.WriteString("rate_limit_1d=") + builder.WriteString(fmt.Sprintf("%v", _m.RateLimit1d)) + builder.WriteString(", ") + builder.WriteString("rate_limit_7d=") + builder.WriteString(fmt.Sprintf("%v", _m.RateLimit7d)) + builder.WriteString(", ") + builder.WriteString("usage_5h=") + builder.WriteString(fmt.Sprintf("%v", _m.Usage5h)) + builder.WriteString(", ") + builder.WriteString("usage_1d=") + builder.WriteString(fmt.Sprintf("%v", _m.Usage1d)) + builder.WriteString(", ") + builder.WriteString("usage_7d=") + builder.WriteString(fmt.Sprintf("%v", _m.Usage7d)) + builder.WriteString(", ") + if v := _m.Window5hStart; v != nil { + builder.WriteString("window_5h_start=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.Window1dStart; v != nil { + builder.WriteString("window_1d_start=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.Window7dStart; v != nil { + builder.WriteString("window_7d_start=") + builder.WriteString(v.Format(time.ANSIC)) + } builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/apikey/apikey.go b/backend/ent/apikey/apikey.go index ac2a6008..d398a027 100644 --- a/backend/ent/apikey/apikey.go +++ b/backend/ent/apikey/apikey.go @@ -31,6 +31,8 @@ const ( FieldGroupID = "group_id" // FieldStatus holds the string denoting the status field in the 
database. FieldStatus = "status" + // FieldLastUsedAt holds the string denoting the last_used_at field in the database. + FieldLastUsedAt = "last_used_at" // FieldIPWhitelist holds the string denoting the ip_whitelist field in the database. FieldIPWhitelist = "ip_whitelist" // FieldIPBlacklist holds the string denoting the ip_blacklist field in the database. @@ -41,6 +43,24 @@ const ( FieldQuotaUsed = "quota_used" // FieldExpiresAt holds the string denoting the expires_at field in the database. FieldExpiresAt = "expires_at" + // FieldRateLimit5h holds the string denoting the rate_limit_5h field in the database. + FieldRateLimit5h = "rate_limit_5h" + // FieldRateLimit1d holds the string denoting the rate_limit_1d field in the database. + FieldRateLimit1d = "rate_limit_1d" + // FieldRateLimit7d holds the string denoting the rate_limit_7d field in the database. + FieldRateLimit7d = "rate_limit_7d" + // FieldUsage5h holds the string denoting the usage_5h field in the database. + FieldUsage5h = "usage_5h" + // FieldUsage1d holds the string denoting the usage_1d field in the database. + FieldUsage1d = "usage_1d" + // FieldUsage7d holds the string denoting the usage_7d field in the database. + FieldUsage7d = "usage_7d" + // FieldWindow5hStart holds the string denoting the window_5h_start field in the database. + FieldWindow5hStart = "window_5h_start" + // FieldWindow1dStart holds the string denoting the window_1d_start field in the database. + FieldWindow1dStart = "window_1d_start" + // FieldWindow7dStart holds the string denoting the window_7d_start field in the database. + FieldWindow7dStart = "window_7d_start" // EdgeUser holds the string denoting the user edge name in mutations. EdgeUser = "user" // EdgeGroup holds the string denoting the group edge name in mutations. 
@@ -83,11 +103,21 @@ var Columns = []string{ FieldName, FieldGroupID, FieldStatus, + FieldLastUsedAt, FieldIPWhitelist, FieldIPBlacklist, FieldQuota, FieldQuotaUsed, FieldExpiresAt, + FieldRateLimit5h, + FieldRateLimit1d, + FieldRateLimit7d, + FieldUsage5h, + FieldUsage1d, + FieldUsage7d, + FieldWindow5hStart, + FieldWindow1dStart, + FieldWindow7dStart, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -126,6 +156,18 @@ var ( DefaultQuota float64 // DefaultQuotaUsed holds the default value on creation for the "quota_used" field. DefaultQuotaUsed float64 + // DefaultRateLimit5h holds the default value on creation for the "rate_limit_5h" field. + DefaultRateLimit5h float64 + // DefaultRateLimit1d holds the default value on creation for the "rate_limit_1d" field. + DefaultRateLimit1d float64 + // DefaultRateLimit7d holds the default value on creation for the "rate_limit_7d" field. + DefaultRateLimit7d float64 + // DefaultUsage5h holds the default value on creation for the "usage_5h" field. + DefaultUsage5h float64 + // DefaultUsage1d holds the default value on creation for the "usage_1d" field. + DefaultUsage1d float64 + // DefaultUsage7d holds the default value on creation for the "usage_7d" field. + DefaultUsage7d float64 ) // OrderOption defines the ordering options for the APIKey queries. @@ -176,6 +218,11 @@ func ByStatus(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldStatus, opts...).ToFunc() } +// ByLastUsedAt orders the results by the last_used_at field. +func ByLastUsedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastUsedAt, opts...).ToFunc() +} + // ByQuota orders the results by the quota field. 
func ByQuota(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldQuota, opts...).ToFunc() @@ -191,6 +238,51 @@ func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldExpiresAt, opts...).ToFunc() } +// ByRateLimit5h orders the results by the rate_limit_5h field. +func ByRateLimit5h(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRateLimit5h, opts...).ToFunc() +} + +// ByRateLimit1d orders the results by the rate_limit_1d field. +func ByRateLimit1d(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRateLimit1d, opts...).ToFunc() +} + +// ByRateLimit7d orders the results by the rate_limit_7d field. +func ByRateLimit7d(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRateLimit7d, opts...).ToFunc() +} + +// ByUsage5h orders the results by the usage_5h field. +func ByUsage5h(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUsage5h, opts...).ToFunc() +} + +// ByUsage1d orders the results by the usage_1d field. +func ByUsage1d(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUsage1d, opts...).ToFunc() +} + +// ByUsage7d orders the results by the usage_7d field. +func ByUsage7d(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUsage7d, opts...).ToFunc() +} + +// ByWindow5hStart orders the results by the window_5h_start field. +func ByWindow5hStart(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldWindow5hStart, opts...).ToFunc() +} + +// ByWindow1dStart orders the results by the window_1d_start field. +func ByWindow1dStart(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldWindow1dStart, opts...).ToFunc() +} + +// ByWindow7dStart orders the results by the window_7d_start field. 
+func ByWindow7dStart(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldWindow7dStart, opts...).ToFunc() +} + // ByUserField orders the results by user field. func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/apikey/where.go b/backend/ent/apikey/where.go index f54f44b7..edd2652b 100644 --- a/backend/ent/apikey/where.go +++ b/backend/ent/apikey/where.go @@ -95,6 +95,11 @@ func Status(v string) predicate.APIKey { return predicate.APIKey(sql.FieldEQ(FieldStatus, v)) } +// LastUsedAt applies equality check predicate on the "last_used_at" field. It's identical to LastUsedAtEQ. +func LastUsedAt(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldLastUsedAt, v)) +} + // Quota applies equality check predicate on the "quota" field. It's identical to QuotaEQ. func Quota(v float64) predicate.APIKey { return predicate.APIKey(sql.FieldEQ(FieldQuota, v)) @@ -110,6 +115,51 @@ func ExpiresAt(v time.Time) predicate.APIKey { return predicate.APIKey(sql.FieldEQ(FieldExpiresAt, v)) } +// RateLimit5h applies equality check predicate on the "rate_limit_5h" field. It's identical to RateLimit5hEQ. +func RateLimit5h(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldRateLimit5h, v)) +} + +// RateLimit1d applies equality check predicate on the "rate_limit_1d" field. It's identical to RateLimit1dEQ. +func RateLimit1d(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldRateLimit1d, v)) +} + +// RateLimit7d applies equality check predicate on the "rate_limit_7d" field. It's identical to RateLimit7dEQ. +func RateLimit7d(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldRateLimit7d, v)) +} + +// Usage5h applies equality check predicate on the "usage_5h" field. It's identical to Usage5hEQ. 
+func Usage5h(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldUsage5h, v)) +} + +// Usage1d applies equality check predicate on the "usage_1d" field. It's identical to Usage1dEQ. +func Usage1d(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldUsage1d, v)) +} + +// Usage7d applies equality check predicate on the "usage_7d" field. It's identical to Usage7dEQ. +func Usage7d(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldUsage7d, v)) +} + +// Window5hStart applies equality check predicate on the "window_5h_start" field. It's identical to Window5hStartEQ. +func Window5hStart(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldWindow5hStart, v)) +} + +// Window1dStart applies equality check predicate on the "window_1d_start" field. It's identical to Window1dStartEQ. +func Window1dStart(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldWindow1dStart, v)) +} + +// Window7dStart applies equality check predicate on the "window_7d_start" field. It's identical to Window7dStartEQ. +func Window7dStart(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldWindow7dStart, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.APIKey { return predicate.APIKey(sql.FieldEQ(FieldCreatedAt, v)) @@ -485,6 +535,56 @@ func StatusContainsFold(v string) predicate.APIKey { return predicate.APIKey(sql.FieldContainsFold(FieldStatus, v)) } +// LastUsedAtEQ applies the EQ predicate on the "last_used_at" field. +func LastUsedAtEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldLastUsedAt, v)) +} + +// LastUsedAtNEQ applies the NEQ predicate on the "last_used_at" field. +func LastUsedAtNEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldLastUsedAt, v)) +} + +// LastUsedAtIn applies the In predicate on the "last_used_at" field. 
+func LastUsedAtIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldLastUsedAt, vs...)) +} + +// LastUsedAtNotIn applies the NotIn predicate on the "last_used_at" field. +func LastUsedAtNotIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldLastUsedAt, vs...)) +} + +// LastUsedAtGT applies the GT predicate on the "last_used_at" field. +func LastUsedAtGT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldLastUsedAt, v)) +} + +// LastUsedAtGTE applies the GTE predicate on the "last_used_at" field. +func LastUsedAtGTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldLastUsedAt, v)) +} + +// LastUsedAtLT applies the LT predicate on the "last_used_at" field. +func LastUsedAtLT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldLastUsedAt, v)) +} + +// LastUsedAtLTE applies the LTE predicate on the "last_used_at" field. +func LastUsedAtLTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldLastUsedAt, v)) +} + +// LastUsedAtIsNil applies the IsNil predicate on the "last_used_at" field. +func LastUsedAtIsNil() predicate.APIKey { + return predicate.APIKey(sql.FieldIsNull(FieldLastUsedAt)) +} + +// LastUsedAtNotNil applies the NotNil predicate on the "last_used_at" field. +func LastUsedAtNotNil() predicate.APIKey { + return predicate.APIKey(sql.FieldNotNull(FieldLastUsedAt)) +} + // IPWhitelistIsNil applies the IsNil predicate on the "ip_whitelist" field. func IPWhitelistIsNil() predicate.APIKey { return predicate.APIKey(sql.FieldIsNull(FieldIPWhitelist)) @@ -635,6 +735,396 @@ func ExpiresAtNotNil() predicate.APIKey { return predicate.APIKey(sql.FieldNotNull(FieldExpiresAt)) } +// RateLimit5hEQ applies the EQ predicate on the "rate_limit_5h" field. 
+func RateLimit5hEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldRateLimit5h, v)) +} + +// RateLimit5hNEQ applies the NEQ predicate on the "rate_limit_5h" field. +func RateLimit5hNEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldRateLimit5h, v)) +} + +// RateLimit5hIn applies the In predicate on the "rate_limit_5h" field. +func RateLimit5hIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldRateLimit5h, vs...)) +} + +// RateLimit5hNotIn applies the NotIn predicate on the "rate_limit_5h" field. +func RateLimit5hNotIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldRateLimit5h, vs...)) +} + +// RateLimit5hGT applies the GT predicate on the "rate_limit_5h" field. +func RateLimit5hGT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldRateLimit5h, v)) +} + +// RateLimit5hGTE applies the GTE predicate on the "rate_limit_5h" field. +func RateLimit5hGTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldRateLimit5h, v)) +} + +// RateLimit5hLT applies the LT predicate on the "rate_limit_5h" field. +func RateLimit5hLT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldRateLimit5h, v)) +} + +// RateLimit5hLTE applies the LTE predicate on the "rate_limit_5h" field. +func RateLimit5hLTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldRateLimit5h, v)) +} + +// RateLimit1dEQ applies the EQ predicate on the "rate_limit_1d" field. +func RateLimit1dEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldRateLimit1d, v)) +} + +// RateLimit1dNEQ applies the NEQ predicate on the "rate_limit_1d" field. +func RateLimit1dNEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldRateLimit1d, v)) +} + +// RateLimit1dIn applies the In predicate on the "rate_limit_1d" field. 
+func RateLimit1dIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldRateLimit1d, vs...)) +} + +// RateLimit1dNotIn applies the NotIn predicate on the "rate_limit_1d" field. +func RateLimit1dNotIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldRateLimit1d, vs...)) +} + +// RateLimit1dGT applies the GT predicate on the "rate_limit_1d" field. +func RateLimit1dGT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldRateLimit1d, v)) +} + +// RateLimit1dGTE applies the GTE predicate on the "rate_limit_1d" field. +func RateLimit1dGTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldRateLimit1d, v)) +} + +// RateLimit1dLT applies the LT predicate on the "rate_limit_1d" field. +func RateLimit1dLT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldRateLimit1d, v)) +} + +// RateLimit1dLTE applies the LTE predicate on the "rate_limit_1d" field. +func RateLimit1dLTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldRateLimit1d, v)) +} + +// RateLimit7dEQ applies the EQ predicate on the "rate_limit_7d" field. +func RateLimit7dEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldRateLimit7d, v)) +} + +// RateLimit7dNEQ applies the NEQ predicate on the "rate_limit_7d" field. +func RateLimit7dNEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldRateLimit7d, v)) +} + +// RateLimit7dIn applies the In predicate on the "rate_limit_7d" field. +func RateLimit7dIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldRateLimit7d, vs...)) +} + +// RateLimit7dNotIn applies the NotIn predicate on the "rate_limit_7d" field. +func RateLimit7dNotIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldRateLimit7d, vs...)) +} + +// RateLimit7dGT applies the GT predicate on the "rate_limit_7d" field. 
+func RateLimit7dGT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldRateLimit7d, v)) +} + +// RateLimit7dGTE applies the GTE predicate on the "rate_limit_7d" field. +func RateLimit7dGTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldRateLimit7d, v)) +} + +// RateLimit7dLT applies the LT predicate on the "rate_limit_7d" field. +func RateLimit7dLT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldRateLimit7d, v)) +} + +// RateLimit7dLTE applies the LTE predicate on the "rate_limit_7d" field. +func RateLimit7dLTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldRateLimit7d, v)) +} + +// Usage5hEQ applies the EQ predicate on the "usage_5h" field. +func Usage5hEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldUsage5h, v)) +} + +// Usage5hNEQ applies the NEQ predicate on the "usage_5h" field. +func Usage5hNEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldUsage5h, v)) +} + +// Usage5hIn applies the In predicate on the "usage_5h" field. +func Usage5hIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldUsage5h, vs...)) +} + +// Usage5hNotIn applies the NotIn predicate on the "usage_5h" field. +func Usage5hNotIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldUsage5h, vs...)) +} + +// Usage5hGT applies the GT predicate on the "usage_5h" field. +func Usage5hGT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldUsage5h, v)) +} + +// Usage5hGTE applies the GTE predicate on the "usage_5h" field. +func Usage5hGTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldUsage5h, v)) +} + +// Usage5hLT applies the LT predicate on the "usage_5h" field. +func Usage5hLT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldUsage5h, v)) +} + +// Usage5hLTE applies the LTE predicate on the "usage_5h" field. 
+func Usage5hLTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldUsage5h, v)) +} + +// Usage1dEQ applies the EQ predicate on the "usage_1d" field. +func Usage1dEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldUsage1d, v)) +} + +// Usage1dNEQ applies the NEQ predicate on the "usage_1d" field. +func Usage1dNEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldUsage1d, v)) +} + +// Usage1dIn applies the In predicate on the "usage_1d" field. +func Usage1dIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldUsage1d, vs...)) +} + +// Usage1dNotIn applies the NotIn predicate on the "usage_1d" field. +func Usage1dNotIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldUsage1d, vs...)) +} + +// Usage1dGT applies the GT predicate on the "usage_1d" field. +func Usage1dGT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldUsage1d, v)) +} + +// Usage1dGTE applies the GTE predicate on the "usage_1d" field. +func Usage1dGTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldUsage1d, v)) +} + +// Usage1dLT applies the LT predicate on the "usage_1d" field. +func Usage1dLT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldUsage1d, v)) +} + +// Usage1dLTE applies the LTE predicate on the "usage_1d" field. +func Usage1dLTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldUsage1d, v)) +} + +// Usage7dEQ applies the EQ predicate on the "usage_7d" field. +func Usage7dEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldUsage7d, v)) +} + +// Usage7dNEQ applies the NEQ predicate on the "usage_7d" field. +func Usage7dNEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldUsage7d, v)) +} + +// Usage7dIn applies the In predicate on the "usage_7d" field. 
+func Usage7dIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldUsage7d, vs...)) +} + +// Usage7dNotIn applies the NotIn predicate on the "usage_7d" field. +func Usage7dNotIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldUsage7d, vs...)) +} + +// Usage7dGT applies the GT predicate on the "usage_7d" field. +func Usage7dGT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldUsage7d, v)) +} + +// Usage7dGTE applies the GTE predicate on the "usage_7d" field. +func Usage7dGTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldUsage7d, v)) +} + +// Usage7dLT applies the LT predicate on the "usage_7d" field. +func Usage7dLT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldUsage7d, v)) +} + +// Usage7dLTE applies the LTE predicate on the "usage_7d" field. +func Usage7dLTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldUsage7d, v)) +} + +// Window5hStartEQ applies the EQ predicate on the "window_5h_start" field. +func Window5hStartEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldWindow5hStart, v)) +} + +// Window5hStartNEQ applies the NEQ predicate on the "window_5h_start" field. +func Window5hStartNEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldWindow5hStart, v)) +} + +// Window5hStartIn applies the In predicate on the "window_5h_start" field. +func Window5hStartIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldWindow5hStart, vs...)) +} + +// Window5hStartNotIn applies the NotIn predicate on the "window_5h_start" field. +func Window5hStartNotIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldWindow5hStart, vs...)) +} + +// Window5hStartGT applies the GT predicate on the "window_5h_start" field. 
+func Window5hStartGT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldWindow5hStart, v)) +} + +// Window5hStartGTE applies the GTE predicate on the "window_5h_start" field. +func Window5hStartGTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldWindow5hStart, v)) +} + +// Window5hStartLT applies the LT predicate on the "window_5h_start" field. +func Window5hStartLT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldWindow5hStart, v)) +} + +// Window5hStartLTE applies the LTE predicate on the "window_5h_start" field. +func Window5hStartLTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldWindow5hStart, v)) +} + +// Window5hStartIsNil applies the IsNil predicate on the "window_5h_start" field. +func Window5hStartIsNil() predicate.APIKey { + return predicate.APIKey(sql.FieldIsNull(FieldWindow5hStart)) +} + +// Window5hStartNotNil applies the NotNil predicate on the "window_5h_start" field. +func Window5hStartNotNil() predicate.APIKey { + return predicate.APIKey(sql.FieldNotNull(FieldWindow5hStart)) +} + +// Window1dStartEQ applies the EQ predicate on the "window_1d_start" field. +func Window1dStartEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldWindow1dStart, v)) +} + +// Window1dStartNEQ applies the NEQ predicate on the "window_1d_start" field. +func Window1dStartNEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldWindow1dStart, v)) +} + +// Window1dStartIn applies the In predicate on the "window_1d_start" field. +func Window1dStartIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldWindow1dStart, vs...)) +} + +// Window1dStartNotIn applies the NotIn predicate on the "window_1d_start" field. 
+func Window1dStartNotIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldWindow1dStart, vs...)) +} + +// Window1dStartGT applies the GT predicate on the "window_1d_start" field. +func Window1dStartGT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldWindow1dStart, v)) +} + +// Window1dStartGTE applies the GTE predicate on the "window_1d_start" field. +func Window1dStartGTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldWindow1dStart, v)) +} + +// Window1dStartLT applies the LT predicate on the "window_1d_start" field. +func Window1dStartLT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldWindow1dStart, v)) +} + +// Window1dStartLTE applies the LTE predicate on the "window_1d_start" field. +func Window1dStartLTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldWindow1dStart, v)) +} + +// Window1dStartIsNil applies the IsNil predicate on the "window_1d_start" field. +func Window1dStartIsNil() predicate.APIKey { + return predicate.APIKey(sql.FieldIsNull(FieldWindow1dStart)) +} + +// Window1dStartNotNil applies the NotNil predicate on the "window_1d_start" field. +func Window1dStartNotNil() predicate.APIKey { + return predicate.APIKey(sql.FieldNotNull(FieldWindow1dStart)) +} + +// Window7dStartEQ applies the EQ predicate on the "window_7d_start" field. +func Window7dStartEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldWindow7dStart, v)) +} + +// Window7dStartNEQ applies the NEQ predicate on the "window_7d_start" field. +func Window7dStartNEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldWindow7dStart, v)) +} + +// Window7dStartIn applies the In predicate on the "window_7d_start" field. 
+func Window7dStartIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldWindow7dStart, vs...)) +} + +// Window7dStartNotIn applies the NotIn predicate on the "window_7d_start" field. +func Window7dStartNotIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldWindow7dStart, vs...)) +} + +// Window7dStartGT applies the GT predicate on the "window_7d_start" field. +func Window7dStartGT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldWindow7dStart, v)) +} + +// Window7dStartGTE applies the GTE predicate on the "window_7d_start" field. +func Window7dStartGTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldWindow7dStart, v)) +} + +// Window7dStartLT applies the LT predicate on the "window_7d_start" field. +func Window7dStartLT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldWindow7dStart, v)) +} + +// Window7dStartLTE applies the LTE predicate on the "window_7d_start" field. +func Window7dStartLTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldWindow7dStart, v)) +} + +// Window7dStartIsNil applies the IsNil predicate on the "window_7d_start" field. +func Window7dStartIsNil() predicate.APIKey { + return predicate.APIKey(sql.FieldIsNull(FieldWindow7dStart)) +} + +// Window7dStartNotNil applies the NotNil predicate on the "window_7d_start" field. +func Window7dStartNotNil() predicate.APIKey { + return predicate.APIKey(sql.FieldNotNull(FieldWindow7dStart)) +} + // HasUser applies the HasEdge predicate on the "user" edge. 
func HasUser() predicate.APIKey { return predicate.APIKey(func(s *sql.Selector) { diff --git a/backend/ent/apikey_create.go b/backend/ent/apikey_create.go index 71540975..4ec8aeaa 100644 --- a/backend/ent/apikey_create.go +++ b/backend/ent/apikey_create.go @@ -113,6 +113,20 @@ func (_c *APIKeyCreate) SetNillableStatus(v *string) *APIKeyCreate { return _c } +// SetLastUsedAt sets the "last_used_at" field. +func (_c *APIKeyCreate) SetLastUsedAt(v time.Time) *APIKeyCreate { + _c.mutation.SetLastUsedAt(v) + return _c +} + +// SetNillableLastUsedAt sets the "last_used_at" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableLastUsedAt(v *time.Time) *APIKeyCreate { + if v != nil { + _c.SetLastUsedAt(*v) + } + return _c +} + // SetIPWhitelist sets the "ip_whitelist" field. func (_c *APIKeyCreate) SetIPWhitelist(v []string) *APIKeyCreate { _c.mutation.SetIPWhitelist(v) @@ -167,6 +181,132 @@ func (_c *APIKeyCreate) SetNillableExpiresAt(v *time.Time) *APIKeyCreate { return _c } +// SetRateLimit5h sets the "rate_limit_5h" field. +func (_c *APIKeyCreate) SetRateLimit5h(v float64) *APIKeyCreate { + _c.mutation.SetRateLimit5h(v) + return _c +} + +// SetNillableRateLimit5h sets the "rate_limit_5h" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableRateLimit5h(v *float64) *APIKeyCreate { + if v != nil { + _c.SetRateLimit5h(*v) + } + return _c +} + +// SetRateLimit1d sets the "rate_limit_1d" field. +func (_c *APIKeyCreate) SetRateLimit1d(v float64) *APIKeyCreate { + _c.mutation.SetRateLimit1d(v) + return _c +} + +// SetNillableRateLimit1d sets the "rate_limit_1d" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableRateLimit1d(v *float64) *APIKeyCreate { + if v != nil { + _c.SetRateLimit1d(*v) + } + return _c +} + +// SetRateLimit7d sets the "rate_limit_7d" field. 
+func (_c *APIKeyCreate) SetRateLimit7d(v float64) *APIKeyCreate { + _c.mutation.SetRateLimit7d(v) + return _c +} + +// SetNillableRateLimit7d sets the "rate_limit_7d" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableRateLimit7d(v *float64) *APIKeyCreate { + if v != nil { + _c.SetRateLimit7d(*v) + } + return _c +} + +// SetUsage5h sets the "usage_5h" field. +func (_c *APIKeyCreate) SetUsage5h(v float64) *APIKeyCreate { + _c.mutation.SetUsage5h(v) + return _c +} + +// SetNillableUsage5h sets the "usage_5h" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableUsage5h(v *float64) *APIKeyCreate { + if v != nil { + _c.SetUsage5h(*v) + } + return _c +} + +// SetUsage1d sets the "usage_1d" field. +func (_c *APIKeyCreate) SetUsage1d(v float64) *APIKeyCreate { + _c.mutation.SetUsage1d(v) + return _c +} + +// SetNillableUsage1d sets the "usage_1d" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableUsage1d(v *float64) *APIKeyCreate { + if v != nil { + _c.SetUsage1d(*v) + } + return _c +} + +// SetUsage7d sets the "usage_7d" field. +func (_c *APIKeyCreate) SetUsage7d(v float64) *APIKeyCreate { + _c.mutation.SetUsage7d(v) + return _c +} + +// SetNillableUsage7d sets the "usage_7d" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableUsage7d(v *float64) *APIKeyCreate { + if v != nil { + _c.SetUsage7d(*v) + } + return _c +} + +// SetWindow5hStart sets the "window_5h_start" field. +func (_c *APIKeyCreate) SetWindow5hStart(v time.Time) *APIKeyCreate { + _c.mutation.SetWindow5hStart(v) + return _c +} + +// SetNillableWindow5hStart sets the "window_5h_start" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableWindow5hStart(v *time.Time) *APIKeyCreate { + if v != nil { + _c.SetWindow5hStart(*v) + } + return _c +} + +// SetWindow1dStart sets the "window_1d_start" field. 
+func (_c *APIKeyCreate) SetWindow1dStart(v time.Time) *APIKeyCreate { + _c.mutation.SetWindow1dStart(v) + return _c +} + +// SetNillableWindow1dStart sets the "window_1d_start" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableWindow1dStart(v *time.Time) *APIKeyCreate { + if v != nil { + _c.SetWindow1dStart(*v) + } + return _c +} + +// SetWindow7dStart sets the "window_7d_start" field. +func (_c *APIKeyCreate) SetWindow7dStart(v time.Time) *APIKeyCreate { + _c.mutation.SetWindow7dStart(v) + return _c +} + +// SetNillableWindow7dStart sets the "window_7d_start" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableWindow7dStart(v *time.Time) *APIKeyCreate { + if v != nil { + _c.SetWindow7dStart(*v) + } + return _c +} + // SetUser sets the "user" edge to the User entity. func (_c *APIKeyCreate) SetUser(v *User) *APIKeyCreate { return _c.SetUserID(v.ID) @@ -255,6 +395,30 @@ func (_c *APIKeyCreate) defaults() error { v := apikey.DefaultQuotaUsed _c.mutation.SetQuotaUsed(v) } + if _, ok := _c.mutation.RateLimit5h(); !ok { + v := apikey.DefaultRateLimit5h + _c.mutation.SetRateLimit5h(v) + } + if _, ok := _c.mutation.RateLimit1d(); !ok { + v := apikey.DefaultRateLimit1d + _c.mutation.SetRateLimit1d(v) + } + if _, ok := _c.mutation.RateLimit7d(); !ok { + v := apikey.DefaultRateLimit7d + _c.mutation.SetRateLimit7d(v) + } + if _, ok := _c.mutation.Usage5h(); !ok { + v := apikey.DefaultUsage5h + _c.mutation.SetUsage5h(v) + } + if _, ok := _c.mutation.Usage1d(); !ok { + v := apikey.DefaultUsage1d + _c.mutation.SetUsage1d(v) + } + if _, ok := _c.mutation.Usage7d(); !ok { + v := apikey.DefaultUsage7d + _c.mutation.SetUsage7d(v) + } return nil } @@ -299,6 +463,24 @@ func (_c *APIKeyCreate) check() error { if _, ok := _c.mutation.QuotaUsed(); !ok { return &ValidationError{Name: "quota_used", err: errors.New(`ent: missing required field "APIKey.quota_used"`)} } + if _, ok := _c.mutation.RateLimit5h(); !ok { + return &ValidationError{Name: 
"rate_limit_5h", err: errors.New(`ent: missing required field "APIKey.rate_limit_5h"`)} + } + if _, ok := _c.mutation.RateLimit1d(); !ok { + return &ValidationError{Name: "rate_limit_1d", err: errors.New(`ent: missing required field "APIKey.rate_limit_1d"`)} + } + if _, ok := _c.mutation.RateLimit7d(); !ok { + return &ValidationError{Name: "rate_limit_7d", err: errors.New(`ent: missing required field "APIKey.rate_limit_7d"`)} + } + if _, ok := _c.mutation.Usage5h(); !ok { + return &ValidationError{Name: "usage_5h", err: errors.New(`ent: missing required field "APIKey.usage_5h"`)} + } + if _, ok := _c.mutation.Usage1d(); !ok { + return &ValidationError{Name: "usage_1d", err: errors.New(`ent: missing required field "APIKey.usage_1d"`)} + } + if _, ok := _c.mutation.Usage7d(); !ok { + return &ValidationError{Name: "usage_7d", err: errors.New(`ent: missing required field "APIKey.usage_7d"`)} + } if len(_c.mutation.UserIDs()) == 0 { return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "APIKey.user"`)} } @@ -353,6 +535,10 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) { _spec.SetField(apikey.FieldStatus, field.TypeString, value) _node.Status = value } + if value, ok := _c.mutation.LastUsedAt(); ok { + _spec.SetField(apikey.FieldLastUsedAt, field.TypeTime, value) + _node.LastUsedAt = &value + } if value, ok := _c.mutation.IPWhitelist(); ok { _spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value) _node.IPWhitelist = value @@ -373,6 +559,42 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) { _spec.SetField(apikey.FieldExpiresAt, field.TypeTime, value) _node.ExpiresAt = &value } + if value, ok := _c.mutation.RateLimit5h(); ok { + _spec.SetField(apikey.FieldRateLimit5h, field.TypeFloat64, value) + _node.RateLimit5h = value + } + if value, ok := _c.mutation.RateLimit1d(); ok { + _spec.SetField(apikey.FieldRateLimit1d, field.TypeFloat64, value) + _node.RateLimit1d = value + } + if value, ok := 
_c.mutation.RateLimit7d(); ok { + _spec.SetField(apikey.FieldRateLimit7d, field.TypeFloat64, value) + _node.RateLimit7d = value + } + if value, ok := _c.mutation.Usage5h(); ok { + _spec.SetField(apikey.FieldUsage5h, field.TypeFloat64, value) + _node.Usage5h = value + } + if value, ok := _c.mutation.Usage1d(); ok { + _spec.SetField(apikey.FieldUsage1d, field.TypeFloat64, value) + _node.Usage1d = value + } + if value, ok := _c.mutation.Usage7d(); ok { + _spec.SetField(apikey.FieldUsage7d, field.TypeFloat64, value) + _node.Usage7d = value + } + if value, ok := _c.mutation.Window5hStart(); ok { + _spec.SetField(apikey.FieldWindow5hStart, field.TypeTime, value) + _node.Window5hStart = &value + } + if value, ok := _c.mutation.Window1dStart(); ok { + _spec.SetField(apikey.FieldWindow1dStart, field.TypeTime, value) + _node.Window1dStart = &value + } + if value, ok := _c.mutation.Window7dStart(); ok { + _spec.SetField(apikey.FieldWindow7dStart, field.TypeTime, value) + _node.Window7dStart = &value + } if nodes := _c.mutation.UserIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -571,6 +793,24 @@ func (u *APIKeyUpsert) UpdateStatus() *APIKeyUpsert { return u } +// SetLastUsedAt sets the "last_used_at" field. +func (u *APIKeyUpsert) SetLastUsedAt(v time.Time) *APIKeyUpsert { + u.Set(apikey.FieldLastUsedAt, v) + return u +} + +// UpdateLastUsedAt sets the "last_used_at" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateLastUsedAt() *APIKeyUpsert { + u.SetExcluded(apikey.FieldLastUsedAt) + return u +} + +// ClearLastUsedAt clears the value of the "last_used_at" field. +func (u *APIKeyUpsert) ClearLastUsedAt() *APIKeyUpsert { + u.SetNull(apikey.FieldLastUsedAt) + return u +} + // SetIPWhitelist sets the "ip_whitelist" field. 
func (u *APIKeyUpsert) SetIPWhitelist(v []string) *APIKeyUpsert { u.Set(apikey.FieldIPWhitelist, v) @@ -661,6 +901,168 @@ func (u *APIKeyUpsert) ClearExpiresAt() *APIKeyUpsert { return u } +// SetRateLimit5h sets the "rate_limit_5h" field. +func (u *APIKeyUpsert) SetRateLimit5h(v float64) *APIKeyUpsert { + u.Set(apikey.FieldRateLimit5h, v) + return u +} + +// UpdateRateLimit5h sets the "rate_limit_5h" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateRateLimit5h() *APIKeyUpsert { + u.SetExcluded(apikey.FieldRateLimit5h) + return u +} + +// AddRateLimit5h adds v to the "rate_limit_5h" field. +func (u *APIKeyUpsert) AddRateLimit5h(v float64) *APIKeyUpsert { + u.Add(apikey.FieldRateLimit5h, v) + return u +} + +// SetRateLimit1d sets the "rate_limit_1d" field. +func (u *APIKeyUpsert) SetRateLimit1d(v float64) *APIKeyUpsert { + u.Set(apikey.FieldRateLimit1d, v) + return u +} + +// UpdateRateLimit1d sets the "rate_limit_1d" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateRateLimit1d() *APIKeyUpsert { + u.SetExcluded(apikey.FieldRateLimit1d) + return u +} + +// AddRateLimit1d adds v to the "rate_limit_1d" field. +func (u *APIKeyUpsert) AddRateLimit1d(v float64) *APIKeyUpsert { + u.Add(apikey.FieldRateLimit1d, v) + return u +} + +// SetRateLimit7d sets the "rate_limit_7d" field. +func (u *APIKeyUpsert) SetRateLimit7d(v float64) *APIKeyUpsert { + u.Set(apikey.FieldRateLimit7d, v) + return u +} + +// UpdateRateLimit7d sets the "rate_limit_7d" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateRateLimit7d() *APIKeyUpsert { + u.SetExcluded(apikey.FieldRateLimit7d) + return u +} + +// AddRateLimit7d adds v to the "rate_limit_7d" field. +func (u *APIKeyUpsert) AddRateLimit7d(v float64) *APIKeyUpsert { + u.Add(apikey.FieldRateLimit7d, v) + return u +} + +// SetUsage5h sets the "usage_5h" field. 
+func (u *APIKeyUpsert) SetUsage5h(v float64) *APIKeyUpsert { + u.Set(apikey.FieldUsage5h, v) + return u +} + +// UpdateUsage5h sets the "usage_5h" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateUsage5h() *APIKeyUpsert { + u.SetExcluded(apikey.FieldUsage5h) + return u +} + +// AddUsage5h adds v to the "usage_5h" field. +func (u *APIKeyUpsert) AddUsage5h(v float64) *APIKeyUpsert { + u.Add(apikey.FieldUsage5h, v) + return u +} + +// SetUsage1d sets the "usage_1d" field. +func (u *APIKeyUpsert) SetUsage1d(v float64) *APIKeyUpsert { + u.Set(apikey.FieldUsage1d, v) + return u +} + +// UpdateUsage1d sets the "usage_1d" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateUsage1d() *APIKeyUpsert { + u.SetExcluded(apikey.FieldUsage1d) + return u +} + +// AddUsage1d adds v to the "usage_1d" field. +func (u *APIKeyUpsert) AddUsage1d(v float64) *APIKeyUpsert { + u.Add(apikey.FieldUsage1d, v) + return u +} + +// SetUsage7d sets the "usage_7d" field. +func (u *APIKeyUpsert) SetUsage7d(v float64) *APIKeyUpsert { + u.Set(apikey.FieldUsage7d, v) + return u +} + +// UpdateUsage7d sets the "usage_7d" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateUsage7d() *APIKeyUpsert { + u.SetExcluded(apikey.FieldUsage7d) + return u +} + +// AddUsage7d adds v to the "usage_7d" field. +func (u *APIKeyUpsert) AddUsage7d(v float64) *APIKeyUpsert { + u.Add(apikey.FieldUsage7d, v) + return u +} + +// SetWindow5hStart sets the "window_5h_start" field. +func (u *APIKeyUpsert) SetWindow5hStart(v time.Time) *APIKeyUpsert { + u.Set(apikey.FieldWindow5hStart, v) + return u +} + +// UpdateWindow5hStart sets the "window_5h_start" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateWindow5hStart() *APIKeyUpsert { + u.SetExcluded(apikey.FieldWindow5hStart) + return u +} + +// ClearWindow5hStart clears the value of the "window_5h_start" field. 
+func (u *APIKeyUpsert) ClearWindow5hStart() *APIKeyUpsert { + u.SetNull(apikey.FieldWindow5hStart) + return u +} + +// SetWindow1dStart sets the "window_1d_start" field. +func (u *APIKeyUpsert) SetWindow1dStart(v time.Time) *APIKeyUpsert { + u.Set(apikey.FieldWindow1dStart, v) + return u +} + +// UpdateWindow1dStart sets the "window_1d_start" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateWindow1dStart() *APIKeyUpsert { + u.SetExcluded(apikey.FieldWindow1dStart) + return u +} + +// ClearWindow1dStart clears the value of the "window_1d_start" field. +func (u *APIKeyUpsert) ClearWindow1dStart() *APIKeyUpsert { + u.SetNull(apikey.FieldWindow1dStart) + return u +} + +// SetWindow7dStart sets the "window_7d_start" field. +func (u *APIKeyUpsert) SetWindow7dStart(v time.Time) *APIKeyUpsert { + u.Set(apikey.FieldWindow7dStart, v) + return u +} + +// UpdateWindow7dStart sets the "window_7d_start" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateWindow7dStart() *APIKeyUpsert { + u.SetExcluded(apikey.FieldWindow7dStart) + return u +} + +// ClearWindow7dStart clears the value of the "window_7d_start" field. +func (u *APIKeyUpsert) ClearWindow7dStart() *APIKeyUpsert { + u.SetNull(apikey.FieldWindow7dStart) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -818,6 +1220,27 @@ func (u *APIKeyUpsertOne) UpdateStatus() *APIKeyUpsertOne { }) } +// SetLastUsedAt sets the "last_used_at" field. +func (u *APIKeyUpsertOne) SetLastUsedAt(v time.Time) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetLastUsedAt(v) + }) +} + +// UpdateLastUsedAt sets the "last_used_at" field to the value that was provided on create. 
+func (u *APIKeyUpsertOne) UpdateLastUsedAt() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateLastUsedAt() + }) +} + +// ClearLastUsedAt clears the value of the "last_used_at" field. +func (u *APIKeyUpsertOne) ClearLastUsedAt() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.ClearLastUsedAt() + }) +} + // SetIPWhitelist sets the "ip_whitelist" field. func (u *APIKeyUpsertOne) SetIPWhitelist(v []string) *APIKeyUpsertOne { return u.Update(func(s *APIKeyUpsert) { @@ -923,6 +1346,195 @@ func (u *APIKeyUpsertOne) ClearExpiresAt() *APIKeyUpsertOne { }) } +// SetRateLimit5h sets the "rate_limit_5h" field. +func (u *APIKeyUpsertOne) SetRateLimit5h(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetRateLimit5h(v) + }) +} + +// AddRateLimit5h adds v to the "rate_limit_5h" field. +func (u *APIKeyUpsertOne) AddRateLimit5h(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.AddRateLimit5h(v) + }) +} + +// UpdateRateLimit5h sets the "rate_limit_5h" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateRateLimit5h() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateRateLimit5h() + }) +} + +// SetRateLimit1d sets the "rate_limit_1d" field. +func (u *APIKeyUpsertOne) SetRateLimit1d(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetRateLimit1d(v) + }) +} + +// AddRateLimit1d adds v to the "rate_limit_1d" field. +func (u *APIKeyUpsertOne) AddRateLimit1d(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.AddRateLimit1d(v) + }) +} + +// UpdateRateLimit1d sets the "rate_limit_1d" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateRateLimit1d() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateRateLimit1d() + }) +} + +// SetRateLimit7d sets the "rate_limit_7d" field. 
+func (u *APIKeyUpsertOne) SetRateLimit7d(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetRateLimit7d(v) + }) +} + +// AddRateLimit7d adds v to the "rate_limit_7d" field. +func (u *APIKeyUpsertOne) AddRateLimit7d(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.AddRateLimit7d(v) + }) +} + +// UpdateRateLimit7d sets the "rate_limit_7d" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateRateLimit7d() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateRateLimit7d() + }) +} + +// SetUsage5h sets the "usage_5h" field. +func (u *APIKeyUpsertOne) SetUsage5h(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetUsage5h(v) + }) +} + +// AddUsage5h adds v to the "usage_5h" field. +func (u *APIKeyUpsertOne) AddUsage5h(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.AddUsage5h(v) + }) +} + +// UpdateUsage5h sets the "usage_5h" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateUsage5h() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateUsage5h() + }) +} + +// SetUsage1d sets the "usage_1d" field. +func (u *APIKeyUpsertOne) SetUsage1d(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetUsage1d(v) + }) +} + +// AddUsage1d adds v to the "usage_1d" field. +func (u *APIKeyUpsertOne) AddUsage1d(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.AddUsage1d(v) + }) +} + +// UpdateUsage1d sets the "usage_1d" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateUsage1d() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateUsage1d() + }) +} + +// SetUsage7d sets the "usage_7d" field. +func (u *APIKeyUpsertOne) SetUsage7d(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetUsage7d(v) + }) +} + +// AddUsage7d adds v to the "usage_7d" field. 
+func (u *APIKeyUpsertOne) AddUsage7d(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.AddUsage7d(v) + }) +} + +// UpdateUsage7d sets the "usage_7d" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateUsage7d() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateUsage7d() + }) +} + +// SetWindow5hStart sets the "window_5h_start" field. +func (u *APIKeyUpsertOne) SetWindow5hStart(v time.Time) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetWindow5hStart(v) + }) +} + +// UpdateWindow5hStart sets the "window_5h_start" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateWindow5hStart() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateWindow5hStart() + }) +} + +// ClearWindow5hStart clears the value of the "window_5h_start" field. +func (u *APIKeyUpsertOne) ClearWindow5hStart() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.ClearWindow5hStart() + }) +} + +// SetWindow1dStart sets the "window_1d_start" field. +func (u *APIKeyUpsertOne) SetWindow1dStart(v time.Time) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetWindow1dStart(v) + }) +} + +// UpdateWindow1dStart sets the "window_1d_start" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateWindow1dStart() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateWindow1dStart() + }) +} + +// ClearWindow1dStart clears the value of the "window_1d_start" field. +func (u *APIKeyUpsertOne) ClearWindow1dStart() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.ClearWindow1dStart() + }) +} + +// SetWindow7dStart sets the "window_7d_start" field. +func (u *APIKeyUpsertOne) SetWindow7dStart(v time.Time) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetWindow7dStart(v) + }) +} + +// UpdateWindow7dStart sets the "window_7d_start" field to the value that was provided on create. 
+func (u *APIKeyUpsertOne) UpdateWindow7dStart() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateWindow7dStart() + }) +} + +// ClearWindow7dStart clears the value of the "window_7d_start" field. +func (u *APIKeyUpsertOne) ClearWindow7dStart() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.ClearWindow7dStart() + }) +} + // Exec executes the query. func (u *APIKeyUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -1246,6 +1858,27 @@ func (u *APIKeyUpsertBulk) UpdateStatus() *APIKeyUpsertBulk { }) } +// SetLastUsedAt sets the "last_used_at" field. +func (u *APIKeyUpsertBulk) SetLastUsedAt(v time.Time) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetLastUsedAt(v) + }) +} + +// UpdateLastUsedAt sets the "last_used_at" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateLastUsedAt() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateLastUsedAt() + }) +} + +// ClearLastUsedAt clears the value of the "last_used_at" field. +func (u *APIKeyUpsertBulk) ClearLastUsedAt() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.ClearLastUsedAt() + }) +} + // SetIPWhitelist sets the "ip_whitelist" field. func (u *APIKeyUpsertBulk) SetIPWhitelist(v []string) *APIKeyUpsertBulk { return u.Update(func(s *APIKeyUpsert) { @@ -1351,6 +1984,195 @@ func (u *APIKeyUpsertBulk) ClearExpiresAt() *APIKeyUpsertBulk { }) } +// SetRateLimit5h sets the "rate_limit_5h" field. +func (u *APIKeyUpsertBulk) SetRateLimit5h(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetRateLimit5h(v) + }) +} + +// AddRateLimit5h adds v to the "rate_limit_5h" field. +func (u *APIKeyUpsertBulk) AddRateLimit5h(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.AddRateLimit5h(v) + }) +} + +// UpdateRateLimit5h sets the "rate_limit_5h" field to the value that was provided on create. 
+func (u *APIKeyUpsertBulk) UpdateRateLimit5h() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateRateLimit5h() + }) +} + +// SetRateLimit1d sets the "rate_limit_1d" field. +func (u *APIKeyUpsertBulk) SetRateLimit1d(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetRateLimit1d(v) + }) +} + +// AddRateLimit1d adds v to the "rate_limit_1d" field. +func (u *APIKeyUpsertBulk) AddRateLimit1d(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.AddRateLimit1d(v) + }) +} + +// UpdateRateLimit1d sets the "rate_limit_1d" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateRateLimit1d() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateRateLimit1d() + }) +} + +// SetRateLimit7d sets the "rate_limit_7d" field. +func (u *APIKeyUpsertBulk) SetRateLimit7d(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetRateLimit7d(v) + }) +} + +// AddRateLimit7d adds v to the "rate_limit_7d" field. +func (u *APIKeyUpsertBulk) AddRateLimit7d(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.AddRateLimit7d(v) + }) +} + +// UpdateRateLimit7d sets the "rate_limit_7d" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateRateLimit7d() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateRateLimit7d() + }) +} + +// SetUsage5h sets the "usage_5h" field. +func (u *APIKeyUpsertBulk) SetUsage5h(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetUsage5h(v) + }) +} + +// AddUsage5h adds v to the "usage_5h" field. +func (u *APIKeyUpsertBulk) AddUsage5h(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.AddUsage5h(v) + }) +} + +// UpdateUsage5h sets the "usage_5h" field to the value that was provided on create. 
+func (u *APIKeyUpsertBulk) UpdateUsage5h() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateUsage5h() + }) +} + +// SetUsage1d sets the "usage_1d" field. +func (u *APIKeyUpsertBulk) SetUsage1d(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetUsage1d(v) + }) +} + +// AddUsage1d adds v to the "usage_1d" field. +func (u *APIKeyUpsertBulk) AddUsage1d(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.AddUsage1d(v) + }) +} + +// UpdateUsage1d sets the "usage_1d" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateUsage1d() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateUsage1d() + }) +} + +// SetUsage7d sets the "usage_7d" field. +func (u *APIKeyUpsertBulk) SetUsage7d(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetUsage7d(v) + }) +} + +// AddUsage7d adds v to the "usage_7d" field. +func (u *APIKeyUpsertBulk) AddUsage7d(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.AddUsage7d(v) + }) +} + +// UpdateUsage7d sets the "usage_7d" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateUsage7d() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateUsage7d() + }) +} + +// SetWindow5hStart sets the "window_5h_start" field. +func (u *APIKeyUpsertBulk) SetWindow5hStart(v time.Time) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetWindow5hStart(v) + }) +} + +// UpdateWindow5hStart sets the "window_5h_start" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateWindow5hStart() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateWindow5hStart() + }) +} + +// ClearWindow5hStart clears the value of the "window_5h_start" field. 
+func (u *APIKeyUpsertBulk) ClearWindow5hStart() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.ClearWindow5hStart() + }) +} + +// SetWindow1dStart sets the "window_1d_start" field. +func (u *APIKeyUpsertBulk) SetWindow1dStart(v time.Time) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetWindow1dStart(v) + }) +} + +// UpdateWindow1dStart sets the "window_1d_start" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateWindow1dStart() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateWindow1dStart() + }) +} + +// ClearWindow1dStart clears the value of the "window_1d_start" field. +func (u *APIKeyUpsertBulk) ClearWindow1dStart() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.ClearWindow1dStart() + }) +} + +// SetWindow7dStart sets the "window_7d_start" field. +func (u *APIKeyUpsertBulk) SetWindow7dStart(v time.Time) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetWindow7dStart(v) + }) +} + +// UpdateWindow7dStart sets the "window_7d_start" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateWindow7dStart() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateWindow7dStart() + }) +} + +// ClearWindow7dStart clears the value of the "window_7d_start" field. +func (u *APIKeyUpsertBulk) ClearWindow7dStart() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.ClearWindow7dStart() + }) +} + // Exec executes the query. func (u *APIKeyUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/apikey_update.go b/backend/ent/apikey_update.go index b4ff230b..db341e4c 100644 --- a/backend/ent/apikey_update.go +++ b/backend/ent/apikey_update.go @@ -134,6 +134,26 @@ func (_u *APIKeyUpdate) SetNillableStatus(v *string) *APIKeyUpdate { return _u } +// SetLastUsedAt sets the "last_used_at" field. 
+func (_u *APIKeyUpdate) SetLastUsedAt(v time.Time) *APIKeyUpdate { + _u.mutation.SetLastUsedAt(v) + return _u +} + +// SetNillableLastUsedAt sets the "last_used_at" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableLastUsedAt(v *time.Time) *APIKeyUpdate { + if v != nil { + _u.SetLastUsedAt(*v) + } + return _u +} + +// ClearLastUsedAt clears the value of the "last_used_at" field. +func (_u *APIKeyUpdate) ClearLastUsedAt() *APIKeyUpdate { + _u.mutation.ClearLastUsedAt() + return _u +} + // SetIPWhitelist sets the "ip_whitelist" field. func (_u *APIKeyUpdate) SetIPWhitelist(v []string) *APIKeyUpdate { _u.mutation.SetIPWhitelist(v) @@ -232,6 +252,192 @@ func (_u *APIKeyUpdate) ClearExpiresAt() *APIKeyUpdate { return _u } +// SetRateLimit5h sets the "rate_limit_5h" field. +func (_u *APIKeyUpdate) SetRateLimit5h(v float64) *APIKeyUpdate { + _u.mutation.ResetRateLimit5h() + _u.mutation.SetRateLimit5h(v) + return _u +} + +// SetNillableRateLimit5h sets the "rate_limit_5h" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableRateLimit5h(v *float64) *APIKeyUpdate { + if v != nil { + _u.SetRateLimit5h(*v) + } + return _u +} + +// AddRateLimit5h adds value to the "rate_limit_5h" field. +func (_u *APIKeyUpdate) AddRateLimit5h(v float64) *APIKeyUpdate { + _u.mutation.AddRateLimit5h(v) + return _u +} + +// SetRateLimit1d sets the "rate_limit_1d" field. +func (_u *APIKeyUpdate) SetRateLimit1d(v float64) *APIKeyUpdate { + _u.mutation.ResetRateLimit1d() + _u.mutation.SetRateLimit1d(v) + return _u +} + +// SetNillableRateLimit1d sets the "rate_limit_1d" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableRateLimit1d(v *float64) *APIKeyUpdate { + if v != nil { + _u.SetRateLimit1d(*v) + } + return _u +} + +// AddRateLimit1d adds value to the "rate_limit_1d" field. 
+func (_u *APIKeyUpdate) AddRateLimit1d(v float64) *APIKeyUpdate { + _u.mutation.AddRateLimit1d(v) + return _u +} + +// SetRateLimit7d sets the "rate_limit_7d" field. +func (_u *APIKeyUpdate) SetRateLimit7d(v float64) *APIKeyUpdate { + _u.mutation.ResetRateLimit7d() + _u.mutation.SetRateLimit7d(v) + return _u +} + +// SetNillableRateLimit7d sets the "rate_limit_7d" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableRateLimit7d(v *float64) *APIKeyUpdate { + if v != nil { + _u.SetRateLimit7d(*v) + } + return _u +} + +// AddRateLimit7d adds value to the "rate_limit_7d" field. +func (_u *APIKeyUpdate) AddRateLimit7d(v float64) *APIKeyUpdate { + _u.mutation.AddRateLimit7d(v) + return _u +} + +// SetUsage5h sets the "usage_5h" field. +func (_u *APIKeyUpdate) SetUsage5h(v float64) *APIKeyUpdate { + _u.mutation.ResetUsage5h() + _u.mutation.SetUsage5h(v) + return _u +} + +// SetNillableUsage5h sets the "usage_5h" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableUsage5h(v *float64) *APIKeyUpdate { + if v != nil { + _u.SetUsage5h(*v) + } + return _u +} + +// AddUsage5h adds value to the "usage_5h" field. +func (_u *APIKeyUpdate) AddUsage5h(v float64) *APIKeyUpdate { + _u.mutation.AddUsage5h(v) + return _u +} + +// SetUsage1d sets the "usage_1d" field. +func (_u *APIKeyUpdate) SetUsage1d(v float64) *APIKeyUpdate { + _u.mutation.ResetUsage1d() + _u.mutation.SetUsage1d(v) + return _u +} + +// SetNillableUsage1d sets the "usage_1d" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableUsage1d(v *float64) *APIKeyUpdate { + if v != nil { + _u.SetUsage1d(*v) + } + return _u +} + +// AddUsage1d adds value to the "usage_1d" field. +func (_u *APIKeyUpdate) AddUsage1d(v float64) *APIKeyUpdate { + _u.mutation.AddUsage1d(v) + return _u +} + +// SetUsage7d sets the "usage_7d" field. 
+func (_u *APIKeyUpdate) SetUsage7d(v float64) *APIKeyUpdate { + _u.mutation.ResetUsage7d() + _u.mutation.SetUsage7d(v) + return _u +} + +// SetNillableUsage7d sets the "usage_7d" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableUsage7d(v *float64) *APIKeyUpdate { + if v != nil { + _u.SetUsage7d(*v) + } + return _u +} + +// AddUsage7d adds value to the "usage_7d" field. +func (_u *APIKeyUpdate) AddUsage7d(v float64) *APIKeyUpdate { + _u.mutation.AddUsage7d(v) + return _u +} + +// SetWindow5hStart sets the "window_5h_start" field. +func (_u *APIKeyUpdate) SetWindow5hStart(v time.Time) *APIKeyUpdate { + _u.mutation.SetWindow5hStart(v) + return _u +} + +// SetNillableWindow5hStart sets the "window_5h_start" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableWindow5hStart(v *time.Time) *APIKeyUpdate { + if v != nil { + _u.SetWindow5hStart(*v) + } + return _u +} + +// ClearWindow5hStart clears the value of the "window_5h_start" field. +func (_u *APIKeyUpdate) ClearWindow5hStart() *APIKeyUpdate { + _u.mutation.ClearWindow5hStart() + return _u +} + +// SetWindow1dStart sets the "window_1d_start" field. +func (_u *APIKeyUpdate) SetWindow1dStart(v time.Time) *APIKeyUpdate { + _u.mutation.SetWindow1dStart(v) + return _u +} + +// SetNillableWindow1dStart sets the "window_1d_start" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableWindow1dStart(v *time.Time) *APIKeyUpdate { + if v != nil { + _u.SetWindow1dStart(*v) + } + return _u +} + +// ClearWindow1dStart clears the value of the "window_1d_start" field. +func (_u *APIKeyUpdate) ClearWindow1dStart() *APIKeyUpdate { + _u.mutation.ClearWindow1dStart() + return _u +} + +// SetWindow7dStart sets the "window_7d_start" field. +func (_u *APIKeyUpdate) SetWindow7dStart(v time.Time) *APIKeyUpdate { + _u.mutation.SetWindow7dStart(v) + return _u +} + +// SetNillableWindow7dStart sets the "window_7d_start" field if the given value is not nil. 
+func (_u *APIKeyUpdate) SetNillableWindow7dStart(v *time.Time) *APIKeyUpdate { + if v != nil { + _u.SetWindow7dStart(*v) + } + return _u +} + +// ClearWindow7dStart clears the value of the "window_7d_start" field. +func (_u *APIKeyUpdate) ClearWindow7dStart() *APIKeyUpdate { + _u.mutation.ClearWindow7dStart() + return _u +} + // SetUser sets the "user" edge to the User entity. func (_u *APIKeyUpdate) SetUser(v *User) *APIKeyUpdate { return _u.SetUserID(v.ID) @@ -390,6 +596,12 @@ func (_u *APIKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.Status(); ok { _spec.SetField(apikey.FieldStatus, field.TypeString, value) } + if value, ok := _u.mutation.LastUsedAt(); ok { + _spec.SetField(apikey.FieldLastUsedAt, field.TypeTime, value) + } + if _u.mutation.LastUsedAtCleared() { + _spec.ClearField(apikey.FieldLastUsedAt, field.TypeTime) + } if value, ok := _u.mutation.IPWhitelist(); ok { _spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value) } @@ -430,6 +642,60 @@ func (_u *APIKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.ExpiresAtCleared() { _spec.ClearField(apikey.FieldExpiresAt, field.TypeTime) } + if value, ok := _u.mutation.RateLimit5h(); ok { + _spec.SetField(apikey.FieldRateLimit5h, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedRateLimit5h(); ok { + _spec.AddField(apikey.FieldRateLimit5h, field.TypeFloat64, value) + } + if value, ok := _u.mutation.RateLimit1d(); ok { + _spec.SetField(apikey.FieldRateLimit1d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedRateLimit1d(); ok { + _spec.AddField(apikey.FieldRateLimit1d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.RateLimit7d(); ok { + _spec.SetField(apikey.FieldRateLimit7d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedRateLimit7d(); ok { + _spec.AddField(apikey.FieldRateLimit7d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.Usage5h(); ok { + 
_spec.SetField(apikey.FieldUsage5h, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedUsage5h(); ok { + _spec.AddField(apikey.FieldUsage5h, field.TypeFloat64, value) + } + if value, ok := _u.mutation.Usage1d(); ok { + _spec.SetField(apikey.FieldUsage1d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedUsage1d(); ok { + _spec.AddField(apikey.FieldUsage1d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.Usage7d(); ok { + _spec.SetField(apikey.FieldUsage7d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedUsage7d(); ok { + _spec.AddField(apikey.FieldUsage7d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.Window5hStart(); ok { + _spec.SetField(apikey.FieldWindow5hStart, field.TypeTime, value) + } + if _u.mutation.Window5hStartCleared() { + _spec.ClearField(apikey.FieldWindow5hStart, field.TypeTime) + } + if value, ok := _u.mutation.Window1dStart(); ok { + _spec.SetField(apikey.FieldWindow1dStart, field.TypeTime, value) + } + if _u.mutation.Window1dStartCleared() { + _spec.ClearField(apikey.FieldWindow1dStart, field.TypeTime) + } + if value, ok := _u.mutation.Window7dStart(); ok { + _spec.SetField(apikey.FieldWindow7dStart, field.TypeTime, value) + } + if _u.mutation.Window7dStartCleared() { + _spec.ClearField(apikey.FieldWindow7dStart, field.TypeTime) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -655,6 +921,26 @@ func (_u *APIKeyUpdateOne) SetNillableStatus(v *string) *APIKeyUpdateOne { return _u } +// SetLastUsedAt sets the "last_used_at" field. +func (_u *APIKeyUpdateOne) SetLastUsedAt(v time.Time) *APIKeyUpdateOne { + _u.mutation.SetLastUsedAt(v) + return _u +} + +// SetNillableLastUsedAt sets the "last_used_at" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableLastUsedAt(v *time.Time) *APIKeyUpdateOne { + if v != nil { + _u.SetLastUsedAt(*v) + } + return _u +} + +// ClearLastUsedAt clears the value of the "last_used_at" field. 
+func (_u *APIKeyUpdateOne) ClearLastUsedAt() *APIKeyUpdateOne { + _u.mutation.ClearLastUsedAt() + return _u +} + // SetIPWhitelist sets the "ip_whitelist" field. func (_u *APIKeyUpdateOne) SetIPWhitelist(v []string) *APIKeyUpdateOne { _u.mutation.SetIPWhitelist(v) @@ -753,6 +1039,192 @@ func (_u *APIKeyUpdateOne) ClearExpiresAt() *APIKeyUpdateOne { return _u } +// SetRateLimit5h sets the "rate_limit_5h" field. +func (_u *APIKeyUpdateOne) SetRateLimit5h(v float64) *APIKeyUpdateOne { + _u.mutation.ResetRateLimit5h() + _u.mutation.SetRateLimit5h(v) + return _u +} + +// SetNillableRateLimit5h sets the "rate_limit_5h" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableRateLimit5h(v *float64) *APIKeyUpdateOne { + if v != nil { + _u.SetRateLimit5h(*v) + } + return _u +} + +// AddRateLimit5h adds value to the "rate_limit_5h" field. +func (_u *APIKeyUpdateOne) AddRateLimit5h(v float64) *APIKeyUpdateOne { + _u.mutation.AddRateLimit5h(v) + return _u +} + +// SetRateLimit1d sets the "rate_limit_1d" field. +func (_u *APIKeyUpdateOne) SetRateLimit1d(v float64) *APIKeyUpdateOne { + _u.mutation.ResetRateLimit1d() + _u.mutation.SetRateLimit1d(v) + return _u +} + +// SetNillableRateLimit1d sets the "rate_limit_1d" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableRateLimit1d(v *float64) *APIKeyUpdateOne { + if v != nil { + _u.SetRateLimit1d(*v) + } + return _u +} + +// AddRateLimit1d adds value to the "rate_limit_1d" field. +func (_u *APIKeyUpdateOne) AddRateLimit1d(v float64) *APIKeyUpdateOne { + _u.mutation.AddRateLimit1d(v) + return _u +} + +// SetRateLimit7d sets the "rate_limit_7d" field. +func (_u *APIKeyUpdateOne) SetRateLimit7d(v float64) *APIKeyUpdateOne { + _u.mutation.ResetRateLimit7d() + _u.mutation.SetRateLimit7d(v) + return _u +} + +// SetNillableRateLimit7d sets the "rate_limit_7d" field if the given value is not nil. 
+func (_u *APIKeyUpdateOne) SetNillableRateLimit7d(v *float64) *APIKeyUpdateOne { + if v != nil { + _u.SetRateLimit7d(*v) + } + return _u +} + +// AddRateLimit7d adds value to the "rate_limit_7d" field. +func (_u *APIKeyUpdateOne) AddRateLimit7d(v float64) *APIKeyUpdateOne { + _u.mutation.AddRateLimit7d(v) + return _u +} + +// SetUsage5h sets the "usage_5h" field. +func (_u *APIKeyUpdateOne) SetUsage5h(v float64) *APIKeyUpdateOne { + _u.mutation.ResetUsage5h() + _u.mutation.SetUsage5h(v) + return _u +} + +// SetNillableUsage5h sets the "usage_5h" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableUsage5h(v *float64) *APIKeyUpdateOne { + if v != nil { + _u.SetUsage5h(*v) + } + return _u +} + +// AddUsage5h adds value to the "usage_5h" field. +func (_u *APIKeyUpdateOne) AddUsage5h(v float64) *APIKeyUpdateOne { + _u.mutation.AddUsage5h(v) + return _u +} + +// SetUsage1d sets the "usage_1d" field. +func (_u *APIKeyUpdateOne) SetUsage1d(v float64) *APIKeyUpdateOne { + _u.mutation.ResetUsage1d() + _u.mutation.SetUsage1d(v) + return _u +} + +// SetNillableUsage1d sets the "usage_1d" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableUsage1d(v *float64) *APIKeyUpdateOne { + if v != nil { + _u.SetUsage1d(*v) + } + return _u +} + +// AddUsage1d adds value to the "usage_1d" field. +func (_u *APIKeyUpdateOne) AddUsage1d(v float64) *APIKeyUpdateOne { + _u.mutation.AddUsage1d(v) + return _u +} + +// SetUsage7d sets the "usage_7d" field. +func (_u *APIKeyUpdateOne) SetUsage7d(v float64) *APIKeyUpdateOne { + _u.mutation.ResetUsage7d() + _u.mutation.SetUsage7d(v) + return _u +} + +// SetNillableUsage7d sets the "usage_7d" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableUsage7d(v *float64) *APIKeyUpdateOne { + if v != nil { + _u.SetUsage7d(*v) + } + return _u +} + +// AddUsage7d adds value to the "usage_7d" field. 
+func (_u *APIKeyUpdateOne) AddUsage7d(v float64) *APIKeyUpdateOne { + _u.mutation.AddUsage7d(v) + return _u +} + +// SetWindow5hStart sets the "window_5h_start" field. +func (_u *APIKeyUpdateOne) SetWindow5hStart(v time.Time) *APIKeyUpdateOne { + _u.mutation.SetWindow5hStart(v) + return _u +} + +// SetNillableWindow5hStart sets the "window_5h_start" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableWindow5hStart(v *time.Time) *APIKeyUpdateOne { + if v != nil { + _u.SetWindow5hStart(*v) + } + return _u +} + +// ClearWindow5hStart clears the value of the "window_5h_start" field. +func (_u *APIKeyUpdateOne) ClearWindow5hStart() *APIKeyUpdateOne { + _u.mutation.ClearWindow5hStart() + return _u +} + +// SetWindow1dStart sets the "window_1d_start" field. +func (_u *APIKeyUpdateOne) SetWindow1dStart(v time.Time) *APIKeyUpdateOne { + _u.mutation.SetWindow1dStart(v) + return _u +} + +// SetNillableWindow1dStart sets the "window_1d_start" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableWindow1dStart(v *time.Time) *APIKeyUpdateOne { + if v != nil { + _u.SetWindow1dStart(*v) + } + return _u +} + +// ClearWindow1dStart clears the value of the "window_1d_start" field. +func (_u *APIKeyUpdateOne) ClearWindow1dStart() *APIKeyUpdateOne { + _u.mutation.ClearWindow1dStart() + return _u +} + +// SetWindow7dStart sets the "window_7d_start" field. +func (_u *APIKeyUpdateOne) SetWindow7dStart(v time.Time) *APIKeyUpdateOne { + _u.mutation.SetWindow7dStart(v) + return _u +} + +// SetNillableWindow7dStart sets the "window_7d_start" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableWindow7dStart(v *time.Time) *APIKeyUpdateOne { + if v != nil { + _u.SetWindow7dStart(*v) + } + return _u +} + +// ClearWindow7dStart clears the value of the "window_7d_start" field. 
+func (_u *APIKeyUpdateOne) ClearWindow7dStart() *APIKeyUpdateOne { + _u.mutation.ClearWindow7dStart() + return _u +} + // SetUser sets the "user" edge to the User entity. func (_u *APIKeyUpdateOne) SetUser(v *User) *APIKeyUpdateOne { return _u.SetUserID(v.ID) @@ -941,6 +1413,12 @@ func (_u *APIKeyUpdateOne) sqlSave(ctx context.Context) (_node *APIKey, err erro if value, ok := _u.mutation.Status(); ok { _spec.SetField(apikey.FieldStatus, field.TypeString, value) } + if value, ok := _u.mutation.LastUsedAt(); ok { + _spec.SetField(apikey.FieldLastUsedAt, field.TypeTime, value) + } + if _u.mutation.LastUsedAtCleared() { + _spec.ClearField(apikey.FieldLastUsedAt, field.TypeTime) + } if value, ok := _u.mutation.IPWhitelist(); ok { _spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value) } @@ -981,6 +1459,60 @@ func (_u *APIKeyUpdateOne) sqlSave(ctx context.Context) (_node *APIKey, err erro if _u.mutation.ExpiresAtCleared() { _spec.ClearField(apikey.FieldExpiresAt, field.TypeTime) } + if value, ok := _u.mutation.RateLimit5h(); ok { + _spec.SetField(apikey.FieldRateLimit5h, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedRateLimit5h(); ok { + _spec.AddField(apikey.FieldRateLimit5h, field.TypeFloat64, value) + } + if value, ok := _u.mutation.RateLimit1d(); ok { + _spec.SetField(apikey.FieldRateLimit1d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedRateLimit1d(); ok { + _spec.AddField(apikey.FieldRateLimit1d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.RateLimit7d(); ok { + _spec.SetField(apikey.FieldRateLimit7d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedRateLimit7d(); ok { + _spec.AddField(apikey.FieldRateLimit7d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.Usage5h(); ok { + _spec.SetField(apikey.FieldUsage5h, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedUsage5h(); ok { + _spec.AddField(apikey.FieldUsage5h, field.TypeFloat64, value) + } + if value, ok := 
_u.mutation.Usage1d(); ok { + _spec.SetField(apikey.FieldUsage1d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedUsage1d(); ok { + _spec.AddField(apikey.FieldUsage1d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.Usage7d(); ok { + _spec.SetField(apikey.FieldUsage7d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedUsage7d(); ok { + _spec.AddField(apikey.FieldUsage7d, field.TypeFloat64, value) + } + if value, ok := _u.mutation.Window5hStart(); ok { + _spec.SetField(apikey.FieldWindow5hStart, field.TypeTime, value) + } + if _u.mutation.Window5hStartCleared() { + _spec.ClearField(apikey.FieldWindow5hStart, field.TypeTime) + } + if value, ok := _u.mutation.Window1dStart(); ok { + _spec.SetField(apikey.FieldWindow1dStart, field.TypeTime, value) + } + if _u.mutation.Window1dStartCleared() { + _spec.ClearField(apikey.FieldWindow1dStart, field.TypeTime) + } + if value, ok := _u.mutation.Window7dStart(); ok { + _spec.SetField(apikey.FieldWindow7dStart, field.TypeTime, value) + } + if _u.mutation.Window7dStartCleared() { + _spec.ClearField(apikey.FieldWindow7dStart, field.TypeTime) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, diff --git a/backend/ent/client.go b/backend/ent/client.go index a791c081..7ebbaa32 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -22,10 +22,12 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" "github.com/Wei-Shaw/sub2api/ent/setting" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" @@ -57,6 +59,8 @@ type Client struct { 
ErrorPassthroughRule *ErrorPassthroughRuleClient // Group is the client for interacting with the Group builders. Group *GroupClient + // IdempotencyRecord is the client for interacting with the IdempotencyRecord builders. + IdempotencyRecord *IdempotencyRecordClient // PromoCode is the client for interacting with the PromoCode builders. PromoCode *PromoCodeClient // PromoCodeUsage is the client for interacting with the PromoCodeUsage builders. @@ -65,6 +69,8 @@ type Client struct { Proxy *ProxyClient // RedeemCode is the client for interacting with the RedeemCode builders. RedeemCode *RedeemCodeClient + // SecuritySecret is the client for interacting with the SecuritySecret builders. + SecuritySecret *SecuritySecretClient // Setting is the client for interacting with the Setting builders. Setting *SettingClient // UsageCleanupTask is the client for interacting with the UsageCleanupTask builders. @@ -99,10 +105,12 @@ func (c *Client) init() { c.AnnouncementRead = NewAnnouncementReadClient(c.config) c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config) c.Group = NewGroupClient(c.config) + c.IdempotencyRecord = NewIdempotencyRecordClient(c.config) c.PromoCode = NewPromoCodeClient(c.config) c.PromoCodeUsage = NewPromoCodeUsageClient(c.config) c.Proxy = NewProxyClient(c.config) c.RedeemCode = NewRedeemCodeClient(c.config) + c.SecuritySecret = NewSecuritySecretClient(c.config) c.Setting = NewSettingClient(c.config) c.UsageCleanupTask = NewUsageCleanupTaskClient(c.config) c.UsageLog = NewUsageLogClient(c.config) @@ -210,10 +218,12 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { AnnouncementRead: NewAnnouncementReadClient(cfg), ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), Group: NewGroupClient(cfg), + IdempotencyRecord: NewIdempotencyRecordClient(cfg), PromoCode: NewPromoCodeClient(cfg), PromoCodeUsage: NewPromoCodeUsageClient(cfg), Proxy: NewProxyClient(cfg), RedeemCode: NewRedeemCodeClient(cfg), + SecuritySecret: 
NewSecuritySecretClient(cfg), Setting: NewSettingClient(cfg), UsageCleanupTask: NewUsageCleanupTaskClient(cfg), UsageLog: NewUsageLogClient(cfg), @@ -248,10 +258,12 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) AnnouncementRead: NewAnnouncementReadClient(cfg), ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), Group: NewGroupClient(cfg), + IdempotencyRecord: NewIdempotencyRecordClient(cfg), PromoCode: NewPromoCodeClient(cfg), PromoCodeUsage: NewPromoCodeUsageClient(cfg), Proxy: NewProxyClient(cfg), RedeemCode: NewRedeemCodeClient(cfg), + SecuritySecret: NewSecuritySecretClient(cfg), Setting: NewSettingClient(cfg), UsageCleanupTask: NewUsageCleanupTaskClient(cfg), UsageLog: NewUsageLogClient(cfg), @@ -290,10 +302,10 @@ func (c *Client) Close() error { func (c *Client) Use(hooks ...Hook) { for _, n := range []interface{ Use(...Hook) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, - c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, - c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User, - c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, - c.UserSubscription, + c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PromoCode, + c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, + c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup, + c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Use(hooks...) 
} @@ -304,10 +316,10 @@ func (c *Client) Use(hooks ...Hook) { func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, - c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, - c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User, - c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, - c.UserSubscription, + c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PromoCode, + c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, + c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup, + c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Intercept(interceptors...) } @@ -330,6 +342,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.ErrorPassthroughRule.mutate(ctx, m) case *GroupMutation: return c.Group.mutate(ctx, m) + case *IdempotencyRecordMutation: + return c.IdempotencyRecord.mutate(ctx, m) case *PromoCodeMutation: return c.PromoCode.mutate(ctx, m) case *PromoCodeUsageMutation: @@ -338,6 +352,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.Proxy.mutate(ctx, m) case *RedeemCodeMutation: return c.RedeemCode.mutate(ctx, m) + case *SecuritySecretMutation: + return c.SecuritySecret.mutate(ctx, m) case *SettingMutation: return c.Setting.mutate(ctx, m) case *UsageCleanupTaskMutation: @@ -1567,6 +1583,139 @@ func (c *GroupClient) mutate(ctx context.Context, m *GroupMutation) (Value, erro } } +// IdempotencyRecordClient is a client for the IdempotencyRecord schema. +type IdempotencyRecordClient struct { + config +} + +// NewIdempotencyRecordClient returns a client for the IdempotencyRecord from the given config. 
+func NewIdempotencyRecordClient(c config) *IdempotencyRecordClient { + return &IdempotencyRecordClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `idempotencyrecord.Hooks(f(g(h())))`. +func (c *IdempotencyRecordClient) Use(hooks ...Hook) { + c.hooks.IdempotencyRecord = append(c.hooks.IdempotencyRecord, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `idempotencyrecord.Intercept(f(g(h())))`. +func (c *IdempotencyRecordClient) Intercept(interceptors ...Interceptor) { + c.inters.IdempotencyRecord = append(c.inters.IdempotencyRecord, interceptors...) +} + +// Create returns a builder for creating a IdempotencyRecord entity. +func (c *IdempotencyRecordClient) Create() *IdempotencyRecordCreate { + mutation := newIdempotencyRecordMutation(c.config, OpCreate) + return &IdempotencyRecordCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of IdempotencyRecord entities. +func (c *IdempotencyRecordClient) CreateBulk(builders ...*IdempotencyRecordCreate) *IdempotencyRecordCreateBulk { + return &IdempotencyRecordCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. 
+func (c *IdempotencyRecordClient) MapCreateBulk(slice any, setFunc func(*IdempotencyRecordCreate, int)) *IdempotencyRecordCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &IdempotencyRecordCreateBulk{err: fmt.Errorf("calling to IdempotencyRecordClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*IdempotencyRecordCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &IdempotencyRecordCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for IdempotencyRecord. +func (c *IdempotencyRecordClient) Update() *IdempotencyRecordUpdate { + mutation := newIdempotencyRecordMutation(c.config, OpUpdate) + return &IdempotencyRecordUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *IdempotencyRecordClient) UpdateOne(_m *IdempotencyRecord) *IdempotencyRecordUpdateOne { + mutation := newIdempotencyRecordMutation(c.config, OpUpdateOne, withIdempotencyRecord(_m)) + return &IdempotencyRecordUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *IdempotencyRecordClient) UpdateOneID(id int64) *IdempotencyRecordUpdateOne { + mutation := newIdempotencyRecordMutation(c.config, OpUpdateOne, withIdempotencyRecordID(id)) + return &IdempotencyRecordUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for IdempotencyRecord. +func (c *IdempotencyRecordClient) Delete() *IdempotencyRecordDelete { + mutation := newIdempotencyRecordMutation(c.config, OpDelete) + return &IdempotencyRecordDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. 
+func (c *IdempotencyRecordClient) DeleteOne(_m *IdempotencyRecord) *IdempotencyRecordDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *IdempotencyRecordClient) DeleteOneID(id int64) *IdempotencyRecordDeleteOne { + builder := c.Delete().Where(idempotencyrecord.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &IdempotencyRecordDeleteOne{builder} +} + +// Query returns a query builder for IdempotencyRecord. +func (c *IdempotencyRecordClient) Query() *IdempotencyRecordQuery { + return &IdempotencyRecordQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeIdempotencyRecord}, + inters: c.Interceptors(), + } +} + +// Get returns a IdempotencyRecord entity by its id. +func (c *IdempotencyRecordClient) Get(ctx context.Context, id int64) (*IdempotencyRecord, error) { + return c.Query().Where(idempotencyrecord.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *IdempotencyRecordClient) GetX(ctx context.Context, id int64) *IdempotencyRecord { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// Hooks returns the client hooks. +func (c *IdempotencyRecordClient) Hooks() []Hook { + return c.hooks.IdempotencyRecord +} + +// Interceptors returns the client interceptors. 
+func (c *IdempotencyRecordClient) Interceptors() []Interceptor { + return c.inters.IdempotencyRecord +} + +func (c *IdempotencyRecordClient) mutate(ctx context.Context, m *IdempotencyRecordMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&IdempotencyRecordCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&IdempotencyRecordUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&IdempotencyRecordUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&IdempotencyRecordDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown IdempotencyRecord mutation op: %q", m.Op()) + } +} + // PromoCodeClient is a client for the PromoCode schema. type PromoCodeClient struct { config @@ -2197,6 +2346,139 @@ func (c *RedeemCodeClient) mutate(ctx context.Context, m *RedeemCodeMutation) (V } } +// SecuritySecretClient is a client for the SecuritySecret schema. +type SecuritySecretClient struct { + config +} + +// NewSecuritySecretClient returns a client for the SecuritySecret from the given config. +func NewSecuritySecretClient(c config) *SecuritySecretClient { + return &SecuritySecretClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `securitysecret.Hooks(f(g(h())))`. +func (c *SecuritySecretClient) Use(hooks ...Hook) { + c.hooks.SecuritySecret = append(c.hooks.SecuritySecret, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `securitysecret.Intercept(f(g(h())))`. +func (c *SecuritySecretClient) Intercept(interceptors ...Interceptor) { + c.inters.SecuritySecret = append(c.inters.SecuritySecret, interceptors...) +} + +// Create returns a builder for creating a SecuritySecret entity. 
+func (c *SecuritySecretClient) Create() *SecuritySecretCreate { + mutation := newSecuritySecretMutation(c.config, OpCreate) + return &SecuritySecretCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of SecuritySecret entities. +func (c *SecuritySecretClient) CreateBulk(builders ...*SecuritySecretCreate) *SecuritySecretCreateBulk { + return &SecuritySecretCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *SecuritySecretClient) MapCreateBulk(slice any, setFunc func(*SecuritySecretCreate, int)) *SecuritySecretCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &SecuritySecretCreateBulk{err: fmt.Errorf("calling to SecuritySecretClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*SecuritySecretCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &SecuritySecretCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for SecuritySecret. +func (c *SecuritySecretClient) Update() *SecuritySecretUpdate { + mutation := newSecuritySecretMutation(c.config, OpUpdate) + return &SecuritySecretUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *SecuritySecretClient) UpdateOne(_m *SecuritySecret) *SecuritySecretUpdateOne { + mutation := newSecuritySecretMutation(c.config, OpUpdateOne, withSecuritySecret(_m)) + return &SecuritySecretUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. 
+func (c *SecuritySecretClient) UpdateOneID(id int64) *SecuritySecretUpdateOne { + mutation := newSecuritySecretMutation(c.config, OpUpdateOne, withSecuritySecretID(id)) + return &SecuritySecretUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for SecuritySecret. +func (c *SecuritySecretClient) Delete() *SecuritySecretDelete { + mutation := newSecuritySecretMutation(c.config, OpDelete) + return &SecuritySecretDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *SecuritySecretClient) DeleteOne(_m *SecuritySecret) *SecuritySecretDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *SecuritySecretClient) DeleteOneID(id int64) *SecuritySecretDeleteOne { + builder := c.Delete().Where(securitysecret.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &SecuritySecretDeleteOne{builder} +} + +// Query returns a query builder for SecuritySecret. +func (c *SecuritySecretClient) Query() *SecuritySecretQuery { + return &SecuritySecretQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeSecuritySecret}, + inters: c.Interceptors(), + } +} + +// Get returns a SecuritySecret entity by its id. +func (c *SecuritySecretClient) Get(ctx context.Context, id int64) (*SecuritySecret, error) { + return c.Query().Where(securitysecret.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *SecuritySecretClient) GetX(ctx context.Context, id int64) *SecuritySecret { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// Hooks returns the client hooks. +func (c *SecuritySecretClient) Hooks() []Hook { + return c.hooks.SecuritySecret +} + +// Interceptors returns the client interceptors. 
+func (c *SecuritySecretClient) Interceptors() []Interceptor { + return c.inters.SecuritySecret +} + +func (c *SecuritySecretClient) mutate(ctx context.Context, m *SecuritySecretMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&SecuritySecretCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&SecuritySecretUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&SecuritySecretUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&SecuritySecretDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown SecuritySecret mutation op: %q", m.Op()) + } +} + // SettingClient is a client for the Setting schema. type SettingClient struct { config @@ -3606,15 +3888,17 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription type ( hooks struct { APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode, - Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup, - UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook + ErrorPassthroughRule, Group, IdempotencyRecord, PromoCode, PromoCodeUsage, + Proxy, RedeemCode, SecuritySecret, Setting, UsageCleanupTask, UsageLog, User, + UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, + UserSubscription []ent.Hook } inters struct { APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode, - Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup, - UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor + ErrorPassthroughRule, Group, IdempotencyRecord, PromoCode, PromoCodeUsage, + Proxy, RedeemCode, SecuritySecret, Setting, UsageCleanupTask, UsageLog, User, + 
UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, + UserSubscription []ent.Interceptor } ) diff --git a/backend/ent/ent.go b/backend/ent/ent.go index 5767a167..5197e4d8 100644 --- a/backend/ent/ent.go +++ b/backend/ent/ent.go @@ -19,10 +19,12 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" "github.com/Wei-Shaw/sub2api/ent/setting" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" @@ -98,10 +100,12 @@ func checkColumn(t, c string) error { announcementread.Table: announcementread.ValidColumn, errorpassthroughrule.Table: errorpassthroughrule.ValidColumn, group.Table: group.ValidColumn, + idempotencyrecord.Table: idempotencyrecord.ValidColumn, promocode.Table: promocode.ValidColumn, promocodeusage.Table: promocodeusage.ValidColumn, proxy.Table: proxy.ValidColumn, redeemcode.Table: redeemcode.ValidColumn, + securitysecret.Table: securitysecret.ValidColumn, setting.Table: setting.ValidColumn, usagecleanuptask.Table: usagecleanuptask.ValidColumn, usagelog.Table: usagelog.ValidColumn, diff --git a/backend/ent/group.go b/backend/ent/group.go index 3c8d68b5..76c3cae2 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -52,6 +52,16 @@ type Group struct { ImagePrice2k *float64 `json:"image_price_2k,omitempty"` // ImagePrice4k holds the value of the "image_price_4k" field. ImagePrice4k *float64 `json:"image_price_4k,omitempty"` + // SoraImagePrice360 holds the value of the "sora_image_price_360" field. 
+ SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"` + // SoraImagePrice540 holds the value of the "sora_image_price_540" field. + SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"` + // SoraVideoPricePerRequest holds the value of the "sora_video_price_per_request" field. + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"` + // SoraVideoPricePerRequestHd holds the value of the "sora_video_price_per_request_hd" field. + SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"` + // SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field. + SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"` // 是否仅允许 Claude Code 客户端 ClaudeCodeOnly bool `json:"claude_code_only,omitempty"` // 非 Claude Code 请求降级使用的分组 ID @@ -178,9 +188,9 @@ func (*Group) scanValues(columns []string) ([]any, error) { values[i] = new([]byte) case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject: values[i] = new(sql.NullBool) - case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: + case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd: values[i] = new(sql.NullFloat64) - case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder: + case group.FieldID, group.FieldDefaultValidityDays, group.FieldSoraStorageQuotaBytes, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder: values[i] = new(sql.NullInt64) 
case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType: values[i] = new(sql.NullString) @@ -317,6 +327,40 @@ func (_m *Group) assignValues(columns []string, values []any) error { _m.ImagePrice4k = new(float64) *_m.ImagePrice4k = value.Float64 } + case group.FieldSoraImagePrice360: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field sora_image_price_360", values[i]) + } else if value.Valid { + _m.SoraImagePrice360 = new(float64) + *_m.SoraImagePrice360 = value.Float64 + } + case group.FieldSoraImagePrice540: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field sora_image_price_540", values[i]) + } else if value.Valid { + _m.SoraImagePrice540 = new(float64) + *_m.SoraImagePrice540 = value.Float64 + } + case group.FieldSoraVideoPricePerRequest: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field sora_video_price_per_request", values[i]) + } else if value.Valid { + _m.SoraVideoPricePerRequest = new(float64) + *_m.SoraVideoPricePerRequest = value.Float64 + } + case group.FieldSoraVideoPricePerRequestHd: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field sora_video_price_per_request_hd", values[i]) + } else if value.Valid { + _m.SoraVideoPricePerRequestHd = new(float64) + *_m.SoraVideoPricePerRequestHd = value.Float64 + } + case group.FieldSoraStorageQuotaBytes: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i]) + } else if value.Valid { + _m.SoraStorageQuotaBytes = value.Int64 + } case group.FieldClaudeCodeOnly: if value, ok := values[i].(*sql.NullBool); !ok { return fmt.Errorf("unexpected type %T for field claude_code_only", values[i]) @@ -514,6 +558,29 @@ func (_m *Group) String() string { 
builder.WriteString(fmt.Sprintf("%v", *v)) } builder.WriteString(", ") + if v := _m.SoraImagePrice360; v != nil { + builder.WriteString("sora_image_price_360=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.SoraImagePrice540; v != nil { + builder.WriteString("sora_image_price_540=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.SoraVideoPricePerRequest; v != nil { + builder.WriteString("sora_video_price_per_request=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.SoraVideoPricePerRequestHd; v != nil { + builder.WriteString("sora_video_price_per_request_hd=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("sora_storage_quota_bytes=") + builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes)) + builder.WriteString(", ") builder.WriteString("claude_code_only=") builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly)) builder.WriteString(", ") diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index 31c67756..6ac4eea1 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -49,6 +49,16 @@ const ( FieldImagePrice2k = "image_price_2k" // FieldImagePrice4k holds the string denoting the image_price_4k field in the database. FieldImagePrice4k = "image_price_4k" + // FieldSoraImagePrice360 holds the string denoting the sora_image_price_360 field in the database. + FieldSoraImagePrice360 = "sora_image_price_360" + // FieldSoraImagePrice540 holds the string denoting the sora_image_price_540 field in the database. + FieldSoraImagePrice540 = "sora_image_price_540" + // FieldSoraVideoPricePerRequest holds the string denoting the sora_video_price_per_request field in the database. 
+ FieldSoraVideoPricePerRequest = "sora_video_price_per_request" + // FieldSoraVideoPricePerRequestHd holds the string denoting the sora_video_price_per_request_hd field in the database. + FieldSoraVideoPricePerRequestHd = "sora_video_price_per_request_hd" + // FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database. + FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes" // FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database. FieldClaudeCodeOnly = "claude_code_only" // FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database. @@ -157,6 +167,11 @@ var Columns = []string{ FieldImagePrice1k, FieldImagePrice2k, FieldImagePrice4k, + FieldSoraImagePrice360, + FieldSoraImagePrice540, + FieldSoraVideoPricePerRequest, + FieldSoraVideoPricePerRequestHd, + FieldSoraStorageQuotaBytes, FieldClaudeCodeOnly, FieldFallbackGroupID, FieldFallbackGroupIDOnInvalidRequest, @@ -220,6 +235,8 @@ var ( SubscriptionTypeValidator func(string) error // DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field. DefaultDefaultValidityDays int + // DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field. + DefaultSoraStorageQuotaBytes int64 // DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field. DefaultClaudeCodeOnly bool // DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field. @@ -325,6 +342,31 @@ func ByImagePrice4k(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc() } +// BySoraImagePrice360 orders the results by the sora_image_price_360 field. 
+func BySoraImagePrice360(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraImagePrice360, opts...).ToFunc() +} + +// BySoraImagePrice540 orders the results by the sora_image_price_540 field. +func BySoraImagePrice540(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraImagePrice540, opts...).ToFunc() +} + +// BySoraVideoPricePerRequest orders the results by the sora_video_price_per_request field. +func BySoraVideoPricePerRequest(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraVideoPricePerRequest, opts...).ToFunc() +} + +// BySoraVideoPricePerRequestHd orders the results by the sora_video_price_per_request_hd field. +func BySoraVideoPricePerRequestHd(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraVideoPricePerRequestHd, opts...).ToFunc() +} + +// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field. +func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc() +} + // ByClaudeCodeOnly orders the results by the claude_code_only field. func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc() diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index cd5197c9..4cf65d0f 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -140,6 +140,31 @@ func ImagePrice4k(v float64) predicate.Group { return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v)) } +// SoraImagePrice360 applies equality check predicate on the "sora_image_price_360" field. It's identical to SoraImagePrice360EQ. +func SoraImagePrice360(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice540 applies equality check predicate on the "sora_image_price_540" field. It's identical to SoraImagePrice540EQ. 
+func SoraImagePrice540(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v)) +} + +// SoraVideoPricePerRequest applies equality check predicate on the "sora_video_price_per_request" field. It's identical to SoraVideoPricePerRequestEQ. +func SoraVideoPricePerRequest(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestHd applies equality check predicate on the "sora_video_price_per_request_hd" field. It's identical to SoraVideoPricePerRequestHdEQ. +func SoraVideoPricePerRequestHd(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ. +func SoraStorageQuotaBytes(v int64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v)) +} + // ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ. func ClaudeCodeOnly(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v)) @@ -1025,6 +1050,246 @@ func ImagePrice4kNotNil() predicate.Group { return predicate.Group(sql.FieldNotNull(FieldImagePrice4k)) } +// SoraImagePrice360EQ applies the EQ predicate on the "sora_image_price_360" field. +func SoraImagePrice360EQ(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360NEQ applies the NEQ predicate on the "sora_image_price_360" field. +func SoraImagePrice360NEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360In applies the In predicate on the "sora_image_price_360" field. 
+func SoraImagePrice360In(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSoraImagePrice360, vs...)) +} + +// SoraImagePrice360NotIn applies the NotIn predicate on the "sora_image_price_360" field. +func SoraImagePrice360NotIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice360, vs...)) +} + +// SoraImagePrice360GT applies the GT predicate on the "sora_image_price_360" field. +func SoraImagePrice360GT(v float64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360GTE applies the GTE predicate on the "sora_image_price_360" field. +func SoraImagePrice360GTE(v float64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360LT applies the LT predicate on the "sora_image_price_360" field. +func SoraImagePrice360LT(v float64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360LTE applies the LTE predicate on the "sora_image_price_360" field. +func SoraImagePrice360LTE(v float64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360IsNil applies the IsNil predicate on the "sora_image_price_360" field. +func SoraImagePrice360IsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice360)) +} + +// SoraImagePrice360NotNil applies the NotNil predicate on the "sora_image_price_360" field. +func SoraImagePrice360NotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice360)) +} + +// SoraImagePrice540EQ applies the EQ predicate on the "sora_image_price_540" field. +func SoraImagePrice540EQ(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540NEQ applies the NEQ predicate on the "sora_image_price_540" field. 
+func SoraImagePrice540NEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540In applies the In predicate on the "sora_image_price_540" field. +func SoraImagePrice540In(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSoraImagePrice540, vs...)) +} + +// SoraImagePrice540NotIn applies the NotIn predicate on the "sora_image_price_540" field. +func SoraImagePrice540NotIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice540, vs...)) +} + +// SoraImagePrice540GT applies the GT predicate on the "sora_image_price_540" field. +func SoraImagePrice540GT(v float64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540GTE applies the GTE predicate on the "sora_image_price_540" field. +func SoraImagePrice540GTE(v float64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540LT applies the LT predicate on the "sora_image_price_540" field. +func SoraImagePrice540LT(v float64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540LTE applies the LTE predicate on the "sora_image_price_540" field. +func SoraImagePrice540LTE(v float64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540IsNil applies the IsNil predicate on the "sora_image_price_540" field. +func SoraImagePrice540IsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice540)) +} + +// SoraImagePrice540NotNil applies the NotNil predicate on the "sora_image_price_540" field. +func SoraImagePrice540NotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice540)) +} + +// SoraVideoPricePerRequestEQ applies the EQ predicate on the "sora_video_price_per_request" field. 
+func SoraVideoPricePerRequestEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestNEQ applies the NEQ predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestNEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestIn applies the In predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequest, vs...)) +} + +// SoraVideoPricePerRequestNotIn applies the NotIn predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestNotIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequest, vs...)) +} + +// SoraVideoPricePerRequestGT applies the GT predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestGT(v float64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestGTE applies the GTE predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestGTE(v float64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestLT applies the LT predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestLT(v float64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestLTE applies the LTE predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestLTE(v float64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestIsNil applies the IsNil predicate on the "sora_video_price_per_request" field. 
+func SoraVideoPricePerRequestIsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequest)) +} + +// SoraVideoPricePerRequestNotNil applies the NotNil predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestNotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequest)) +} + +// SoraVideoPricePerRequestHdEQ applies the EQ predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdNEQ applies the NEQ predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdNEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdIn applies the In predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequestHd, vs...)) +} + +// SoraVideoPricePerRequestHdNotIn applies the NotIn predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdNotIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequestHd, vs...)) +} + +// SoraVideoPricePerRequestHdGT applies the GT predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdGT(v float64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdGTE applies the GTE predicate on the "sora_video_price_per_request_hd" field. 
+func SoraVideoPricePerRequestHdGTE(v float64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdLT applies the LT predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdLT(v float64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdLTE applies the LTE predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdLTE(v float64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdIsNil applies the IsNil predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdIsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequestHd)) +} + +// SoraVideoPricePerRequestHdNotNil applies the NotNil predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdNotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequestHd)) +} + +// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesEQ(v int64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesNEQ(v int64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesIn(vs ...int64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...)) +} + +// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field. 
+func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...)) +} + +// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesGT(v int64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesGTE(v int64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesLT(v int64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesLTE(v int64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSoraStorageQuotaBytes, v)) +} + // ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field. func ClaudeCodeOnlyEQ(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v)) diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 707600a7..0ce5f959 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -258,6 +258,76 @@ func (_c *GroupCreate) SetNillableImagePrice4k(v *float64) *GroupCreate { return _c } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (_c *GroupCreate) SetSoraImagePrice360(v float64) *GroupCreate { + _c.mutation.SetSoraImagePrice360(v) + return _c +} + +// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil. 
+func (_c *GroupCreate) SetNillableSoraImagePrice360(v *float64) *GroupCreate { + if v != nil { + _c.SetSoraImagePrice360(*v) + } + return _c +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (_c *GroupCreate) SetSoraImagePrice540(v float64) *GroupCreate { + _c.mutation.SetSoraImagePrice540(v) + return _c +} + +// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSoraImagePrice540(v *float64) *GroupCreate { + if v != nil { + _c.SetSoraImagePrice540(*v) + } + return _c +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (_c *GroupCreate) SetSoraVideoPricePerRequest(v float64) *GroupCreate { + _c.mutation.SetSoraVideoPricePerRequest(v) + return _c +} + +// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupCreate { + if v != nil { + _c.SetSoraVideoPricePerRequest(*v) + } + return _c +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (_c *GroupCreate) SetSoraVideoPricePerRequestHd(v float64) *GroupCreate { + _c.mutation.SetSoraVideoPricePerRequestHd(v) + return _c +} + +// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupCreate { + if v != nil { + _c.SetSoraVideoPricePerRequestHd(*v) + } + return _c +} + +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (_c *GroupCreate) SetSoraStorageQuotaBytes(v int64) *GroupCreate { + _c.mutation.SetSoraStorageQuotaBytes(v) + return _c +} + +// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. 
+func (_c *GroupCreate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupCreate { + if v != nil { + _c.SetSoraStorageQuotaBytes(*v) + } + return _c +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate { _c.mutation.SetClaudeCodeOnly(v) @@ -519,6 +589,10 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultDefaultValidityDays _c.mutation.SetDefaultValidityDays(v) } + if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok { + v := group.DefaultSoraStorageQuotaBytes + _c.mutation.SetSoraStorageQuotaBytes(v) + } if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { v := group.DefaultClaudeCodeOnly _c.mutation.SetClaudeCodeOnly(v) @@ -591,6 +665,9 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.DefaultValidityDays(); !ok { return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)} } + if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok { + return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "Group.sora_storage_quota_bytes"`)} + } if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)} } @@ -701,6 +778,26 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value) _node.ImagePrice4k = &value } + if value, ok := _c.mutation.SoraImagePrice360(); ok { + _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + _node.SoraImagePrice360 = &value + } + if value, ok := _c.mutation.SoraImagePrice540(); ok { + _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + _node.SoraImagePrice540 = &value + } + if value, ok := _c.mutation.SoraVideoPricePerRequest(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + 
_node.SoraVideoPricePerRequest = &value + } + if value, ok := _c.mutation.SoraVideoPricePerRequestHd(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + _node.SoraVideoPricePerRequestHd = &value + } + if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok { + _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + _node.SoraStorageQuotaBytes = value + } if value, ok := _c.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) _node.ClaudeCodeOnly = value @@ -1177,6 +1274,120 @@ func (u *GroupUpsert) ClearImagePrice4k() *GroupUpsert { return u } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (u *GroupUpsert) SetSoraImagePrice360(v float64) *GroupUpsert { + u.Set(group.FieldSoraImagePrice360, v) + return u +} + +// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSoraImagePrice360() *GroupUpsert { + u.SetExcluded(group.FieldSoraImagePrice360) + return u +} + +// AddSoraImagePrice360 adds v to the "sora_image_price_360" field. +func (u *GroupUpsert) AddSoraImagePrice360(v float64) *GroupUpsert { + u.Add(group.FieldSoraImagePrice360, v) + return u +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (u *GroupUpsert) ClearSoraImagePrice360() *GroupUpsert { + u.SetNull(group.FieldSoraImagePrice360) + return u +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (u *GroupUpsert) SetSoraImagePrice540(v float64) *GroupUpsert { + u.Set(group.FieldSoraImagePrice540, v) + return u +} + +// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSoraImagePrice540() *GroupUpsert { + u.SetExcluded(group.FieldSoraImagePrice540) + return u +} + +// AddSoraImagePrice540 adds v to the "sora_image_price_540" field. 
+func (u *GroupUpsert) AddSoraImagePrice540(v float64) *GroupUpsert { + u.Add(group.FieldSoraImagePrice540, v) + return u +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (u *GroupUpsert) ClearSoraImagePrice540() *GroupUpsert { + u.SetNull(group.FieldSoraImagePrice540) + return u +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (u *GroupUpsert) SetSoraVideoPricePerRequest(v float64) *GroupUpsert { + u.Set(group.FieldSoraVideoPricePerRequest, v) + return u +} + +// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSoraVideoPricePerRequest() *GroupUpsert { + u.SetExcluded(group.FieldSoraVideoPricePerRequest) + return u +} + +// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field. +func (u *GroupUpsert) AddSoraVideoPricePerRequest(v float64) *GroupUpsert { + u.Add(group.FieldSoraVideoPricePerRequest, v) + return u +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (u *GroupUpsert) ClearSoraVideoPricePerRequest() *GroupUpsert { + u.SetNull(group.FieldSoraVideoPricePerRequest) + return u +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (u *GroupUpsert) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsert { + u.Set(group.FieldSoraVideoPricePerRequestHd, v) + return u +} + +// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSoraVideoPricePerRequestHd() *GroupUpsert { + u.SetExcluded(group.FieldSoraVideoPricePerRequestHd) + return u +} + +// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field. 
+func (u *GroupUpsert) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsert { + u.Add(group.FieldSoraVideoPricePerRequestHd, v) + return u +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (u *GroupUpsert) ClearSoraVideoPricePerRequestHd() *GroupUpsert { + u.SetNull(group.FieldSoraVideoPricePerRequestHd) + return u +} + +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (u *GroupUpsert) SetSoraStorageQuotaBytes(v int64) *GroupUpsert { + u.Set(group.FieldSoraStorageQuotaBytes, v) + return u +} + +// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSoraStorageQuotaBytes() *GroupUpsert { + u.SetExcluded(group.FieldSoraStorageQuotaBytes) + return u +} + +// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. +func (u *GroupUpsert) AddSoraStorageQuotaBytes(v int64) *GroupUpsert { + u.Add(group.FieldSoraStorageQuotaBytes, v) + return u +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert { u.Set(group.FieldClaudeCodeOnly, v) @@ -1690,6 +1901,139 @@ func (u *GroupUpsertOne) ClearImagePrice4k() *GroupUpsertOne { }) } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (u *GroupUpsertOne) SetSoraImagePrice360(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSoraImagePrice360(v) + }) +} + +// AddSoraImagePrice360 adds v to the "sora_image_price_360" field. +func (u *GroupUpsertOne) AddSoraImagePrice360(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSoraImagePrice360(v) + }) +} + +// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create. 
+func (u *GroupUpsertOne) UpdateSoraImagePrice360() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraImagePrice360() + }) +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (u *GroupUpsertOne) ClearSoraImagePrice360() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraImagePrice360() + }) +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (u *GroupUpsertOne) SetSoraImagePrice540(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSoraImagePrice540(v) + }) +} + +// AddSoraImagePrice540 adds v to the "sora_image_price_540" field. +func (u *GroupUpsertOne) AddSoraImagePrice540(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSoraImagePrice540(v) + }) +} + +// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSoraImagePrice540() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraImagePrice540() + }) +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (u *GroupUpsertOne) ClearSoraImagePrice540() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraImagePrice540() + }) +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (u *GroupUpsertOne) SetSoraVideoPricePerRequest(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSoraVideoPricePerRequest(v) + }) +} + +// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field. +func (u *GroupUpsertOne) AddSoraVideoPricePerRequest(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSoraVideoPricePerRequest(v) + }) +} + +// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create. 
+func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequest() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraVideoPricePerRequest() + }) +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (u *GroupUpsertOne) ClearSoraVideoPricePerRequest() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraVideoPricePerRequest() + }) +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSoraVideoPricePerRequestHd(v) + }) +} + +// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSoraVideoPricePerRequestHd(v) + }) +} + +// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequestHd() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraVideoPricePerRequestHd() + }) +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertOne) ClearSoraVideoPricePerRequestHd() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraVideoPricePerRequestHd() + }) +} + +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (u *GroupUpsertOne) SetSoraStorageQuotaBytes(v int64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSoraStorageQuotaBytes(v) + }) +} + +// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. 
+func (u *GroupUpsertOne) AddSoraStorageQuotaBytes(v int64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSoraStorageQuotaBytes(v) + }) +} + +// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSoraStorageQuotaBytes() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraStorageQuotaBytes() + }) +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne { return u.Update(func(s *GroupUpsert) { @@ -2391,6 +2735,139 @@ func (u *GroupUpsertBulk) ClearImagePrice4k() *GroupUpsertBulk { }) } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (u *GroupUpsertBulk) SetSoraImagePrice360(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSoraImagePrice360(v) + }) +} + +// AddSoraImagePrice360 adds v to the "sora_image_price_360" field. +func (u *GroupUpsertBulk) AddSoraImagePrice360(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSoraImagePrice360(v) + }) +} + +// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSoraImagePrice360() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraImagePrice360() + }) +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (u *GroupUpsertBulk) ClearSoraImagePrice360() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraImagePrice360() + }) +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (u *GroupUpsertBulk) SetSoraImagePrice540(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSoraImagePrice540(v) + }) +} + +// AddSoraImagePrice540 adds v to the "sora_image_price_540" field. 
+func (u *GroupUpsertBulk) AddSoraImagePrice540(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSoraImagePrice540(v) + }) +} + +// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSoraImagePrice540() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraImagePrice540() + }) +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (u *GroupUpsertBulk) ClearSoraImagePrice540() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraImagePrice540() + }) +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (u *GroupUpsertBulk) SetSoraVideoPricePerRequest(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSoraVideoPricePerRequest(v) + }) +} + +// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field. +func (u *GroupUpsertBulk) AddSoraVideoPricePerRequest(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSoraVideoPricePerRequest(v) + }) +} + +// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequest() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraVideoPricePerRequest() + }) +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequest() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraVideoPricePerRequest() + }) +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. 
+func (u *GroupUpsertBulk) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSoraVideoPricePerRequestHd(v) + }) +} + +// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertBulk) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSoraVideoPricePerRequestHd(v) + }) +} + +// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequestHd() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraVideoPricePerRequestHd() + }) +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequestHd() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraVideoPricePerRequestHd() + }) +} + +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (u *GroupUpsertBulk) SetSoraStorageQuotaBytes(v int64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSoraStorageQuotaBytes(v) + }) +} + +// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. +func (u *GroupUpsertBulk) AddSoraStorageQuotaBytes(v int64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSoraStorageQuotaBytes(v) + }) +} + +// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSoraStorageQuotaBytes() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraStorageQuotaBytes() + }) +} + // SetClaudeCodeOnly sets the "claude_code_only" field. 
func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk { return u.Update(func(s *GroupUpsert) { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index 393fd304..85575292 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -355,6 +355,135 @@ func (_u *GroupUpdate) ClearImagePrice4k() *GroupUpdate { return _u } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (_u *GroupUpdate) SetSoraImagePrice360(v float64) *GroupUpdate { + _u.mutation.ResetSoraImagePrice360() + _u.mutation.SetSoraImagePrice360(v) + return _u +} + +// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSoraImagePrice360(v *float64) *GroupUpdate { + if v != nil { + _u.SetSoraImagePrice360(*v) + } + return _u +} + +// AddSoraImagePrice360 adds value to the "sora_image_price_360" field. +func (_u *GroupUpdate) AddSoraImagePrice360(v float64) *GroupUpdate { + _u.mutation.AddSoraImagePrice360(v) + return _u +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (_u *GroupUpdate) ClearSoraImagePrice360() *GroupUpdate { + _u.mutation.ClearSoraImagePrice360() + return _u +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (_u *GroupUpdate) SetSoraImagePrice540(v float64) *GroupUpdate { + _u.mutation.ResetSoraImagePrice540() + _u.mutation.SetSoraImagePrice540(v) + return _u +} + +// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSoraImagePrice540(v *float64) *GroupUpdate { + if v != nil { + _u.SetSoraImagePrice540(*v) + } + return _u +} + +// AddSoraImagePrice540 adds value to the "sora_image_price_540" field. 
+func (_u *GroupUpdate) AddSoraImagePrice540(v float64) *GroupUpdate { + _u.mutation.AddSoraImagePrice540(v) + return _u +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (_u *GroupUpdate) ClearSoraImagePrice540() *GroupUpdate { + _u.mutation.ClearSoraImagePrice540() + return _u +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (_u *GroupUpdate) SetSoraVideoPricePerRequest(v float64) *GroupUpdate { + _u.mutation.ResetSoraVideoPricePerRequest() + _u.mutation.SetSoraVideoPricePerRequest(v) + return _u +} + +// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdate { + if v != nil { + _u.SetSoraVideoPricePerRequest(*v) + } + return _u +} + +// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field. +func (_u *GroupUpdate) AddSoraVideoPricePerRequest(v float64) *GroupUpdate { + _u.mutation.AddSoraVideoPricePerRequest(v) + return _u +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (_u *GroupUpdate) ClearSoraVideoPricePerRequest() *GroupUpdate { + _u.mutation.ClearSoraVideoPricePerRequest() + return _u +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdate) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdate { + _u.mutation.ResetSoraVideoPricePerRequestHd() + _u.mutation.SetSoraVideoPricePerRequestHd(v) + return _u +} + +// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdate { + if v != nil { + _u.SetSoraVideoPricePerRequestHd(*v) + } + return _u +} + +// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field. 
+func (_u *GroupUpdate) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdate { + _u.mutation.AddSoraVideoPricePerRequestHd(v) + return _u +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdate) ClearSoraVideoPricePerRequestHd() *GroupUpdate { + _u.mutation.ClearSoraVideoPricePerRequestHd() + return _u +} + +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (_u *GroupUpdate) SetSoraStorageQuotaBytes(v int64) *GroupUpdate { + _u.mutation.ResetSoraStorageQuotaBytes() + _u.mutation.SetSoraStorageQuotaBytes(v) + return _u +} + +// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdate { + if v != nil { + _u.SetSoraStorageQuotaBytes(*v) + } + return _u +} + +// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field. +func (_u *GroupUpdate) AddSoraStorageQuotaBytes(v int64) *GroupUpdate { + _u.mutation.AddSoraStorageQuotaBytes(v) + return _u +} + // SetClaudeCodeOnly sets the "claude_code_only" field. 
func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate { _u.mutation.SetClaudeCodeOnly(v) @@ -892,6 +1021,48 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.ImagePrice4kCleared() { _spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64) } + if value, ok := _u.mutation.SoraImagePrice360(); ok { + _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraImagePrice360(); ok { + _spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + } + if _u.mutation.SoraImagePrice360Cleared() { + _spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraImagePrice540(); ok { + _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraImagePrice540(); ok { + _spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + } + if _u.mutation.SoraImagePrice540Cleared() { + _spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok { + _spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + } + if _u.mutation.SoraVideoPricePerRequestCleared() { + _spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok { + _spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + } + if _u.mutation.SoraVideoPricePerRequestHdCleared() { + _spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64) + } + if value, ok := 
_u.mutation.SoraStorageQuotaBytes(); ok { + _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok { + _spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } if value, ok := _u.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) } @@ -1573,6 +1744,135 @@ func (_u *GroupUpdateOne) ClearImagePrice4k() *GroupUpdateOne { return _u } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (_u *GroupUpdateOne) SetSoraImagePrice360(v float64) *GroupUpdateOne { + _u.mutation.ResetSoraImagePrice360() + _u.mutation.SetSoraImagePrice360(v) + return _u +} + +// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableSoraImagePrice360(v *float64) *GroupUpdateOne { + if v != nil { + _u.SetSoraImagePrice360(*v) + } + return _u +} + +// AddSoraImagePrice360 adds value to the "sora_image_price_360" field. +func (_u *GroupUpdateOne) AddSoraImagePrice360(v float64) *GroupUpdateOne { + _u.mutation.AddSoraImagePrice360(v) + return _u +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (_u *GroupUpdateOne) ClearSoraImagePrice360() *GroupUpdateOne { + _u.mutation.ClearSoraImagePrice360() + return _u +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (_u *GroupUpdateOne) SetSoraImagePrice540(v float64) *GroupUpdateOne { + _u.mutation.ResetSoraImagePrice540() + _u.mutation.SetSoraImagePrice540(v) + return _u +} + +// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableSoraImagePrice540(v *float64) *GroupUpdateOne { + if v != nil { + _u.SetSoraImagePrice540(*v) + } + return _u +} + +// AddSoraImagePrice540 adds value to the "sora_image_price_540" field. 
+func (_u *GroupUpdateOne) AddSoraImagePrice540(v float64) *GroupUpdateOne { + _u.mutation.AddSoraImagePrice540(v) + return _u +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (_u *GroupUpdateOne) ClearSoraImagePrice540() *GroupUpdateOne { + _u.mutation.ClearSoraImagePrice540() + return _u +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (_u *GroupUpdateOne) SetSoraVideoPricePerRequest(v float64) *GroupUpdateOne { + _u.mutation.ResetSoraVideoPricePerRequest() + _u.mutation.SetSoraVideoPricePerRequest(v) + return _u +} + +// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdateOne { + if v != nil { + _u.SetSoraVideoPricePerRequest(*v) + } + return _u +} + +// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field. +func (_u *GroupUpdateOne) AddSoraVideoPricePerRequest(v float64) *GroupUpdateOne { + _u.mutation.AddSoraVideoPricePerRequest(v) + return _u +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequest() *GroupUpdateOne { + _u.mutation.ClearSoraVideoPricePerRequest() + return _u +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdateOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne { + _u.mutation.ResetSoraVideoPricePerRequestHd() + _u.mutation.SetSoraVideoPricePerRequestHd(v) + return _u +} + +// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil. 
+func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdateOne { + if v != nil { + _u.SetSoraVideoPricePerRequestHd(*v) + } + return _u +} + +// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdateOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne { + _u.mutation.AddSoraVideoPricePerRequestHd(v) + return _u +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequestHd() *GroupUpdateOne { + _u.mutation.ClearSoraVideoPricePerRequestHd() + return _u +} + +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (_u *GroupUpdateOne) SetSoraStorageQuotaBytes(v int64) *GroupUpdateOne { + _u.mutation.ResetSoraStorageQuotaBytes() + _u.mutation.SetSoraStorageQuotaBytes(v) + return _u +} + +// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdateOne { + if v != nil { + _u.SetSoraStorageQuotaBytes(*v) + } + return _u +} + +// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field. +func (_u *GroupUpdateOne) AddSoraStorageQuotaBytes(v int64) *GroupUpdateOne { + _u.mutation.AddSoraStorageQuotaBytes(v) + return _u +} + // SetClaudeCodeOnly sets the "claude_code_only" field. 
func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne { _u.mutation.SetClaudeCodeOnly(v) @@ -2140,6 +2440,48 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if _u.mutation.ImagePrice4kCleared() { _spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64) } + if value, ok := _u.mutation.SoraImagePrice360(); ok { + _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraImagePrice360(); ok { + _spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + } + if _u.mutation.SoraImagePrice360Cleared() { + _spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraImagePrice540(); ok { + _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraImagePrice540(); ok { + _spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + } + if _u.mutation.SoraImagePrice540Cleared() { + _spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok { + _spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + } + if _u.mutation.SoraVideoPricePerRequestCleared() { + _spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok { + _spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + } + if _u.mutation.SoraVideoPricePerRequestHdCleared() { + _spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64) + } + if value, ok := 
_u.mutation.SoraStorageQuotaBytes(); ok { + _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok { + _spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } if value, ok := _u.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) } diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go index 1b15685c..49d7f3c5 100644 --- a/backend/ent/hook/hook.go +++ b/backend/ent/hook/hook.go @@ -93,6 +93,18 @@ func (f GroupFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.GroupMutation", m) } +// The IdempotencyRecordFunc type is an adapter to allow the use of ordinary +// function as IdempotencyRecord mutator. +type IdempotencyRecordFunc func(context.Context, *ent.IdempotencyRecordMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f IdempotencyRecordFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.IdempotencyRecordMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdempotencyRecordMutation", m) +} + // The PromoCodeFunc type is an adapter to allow the use of ordinary // function as PromoCode mutator. type PromoCodeFunc func(context.Context, *ent.PromoCodeMutation) (ent.Value, error) @@ -141,6 +153,18 @@ func (f RedeemCodeFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.RedeemCodeMutation", m) } +// The SecuritySecretFunc type is an adapter to allow the use of ordinary +// function as SecuritySecret mutator. +type SecuritySecretFunc func(context.Context, *ent.SecuritySecretMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). 
+func (f SecuritySecretFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.SecuritySecretMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SecuritySecretMutation", m) +} + // The SettingFunc type is an adapter to allow the use of ordinary // function as Setting mutator. type SettingFunc func(context.Context, *ent.SettingMutation) (ent.Value, error) diff --git a/backend/ent/idempotencyrecord.go b/backend/ent/idempotencyrecord.go new file mode 100644 index 00000000..ab120f8f --- /dev/null +++ b/backend/ent/idempotencyrecord.go @@ -0,0 +1,228 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" +) + +// IdempotencyRecord is the model entity for the IdempotencyRecord schema. +type IdempotencyRecord struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // Scope holds the value of the "scope" field. + Scope string `json:"scope,omitempty"` + // IdempotencyKeyHash holds the value of the "idempotency_key_hash" field. + IdempotencyKeyHash string `json:"idempotency_key_hash,omitempty"` + // RequestFingerprint holds the value of the "request_fingerprint" field. + RequestFingerprint string `json:"request_fingerprint,omitempty"` + // Status holds the value of the "status" field. + Status string `json:"status,omitempty"` + // ResponseStatus holds the value of the "response_status" field. + ResponseStatus *int `json:"response_status,omitempty"` + // ResponseBody holds the value of the "response_body" field. 
+ ResponseBody *string `json:"response_body,omitempty"` + // ErrorReason holds the value of the "error_reason" field. + ErrorReason *string `json:"error_reason,omitempty"` + // LockedUntil holds the value of the "locked_until" field. + LockedUntil *time.Time `json:"locked_until,omitempty"` + // ExpiresAt holds the value of the "expires_at" field. + ExpiresAt time.Time `json:"expires_at,omitempty"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*IdempotencyRecord) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case idempotencyrecord.FieldID, idempotencyrecord.FieldResponseStatus: + values[i] = new(sql.NullInt64) + case idempotencyrecord.FieldScope, idempotencyrecord.FieldIdempotencyKeyHash, idempotencyrecord.FieldRequestFingerprint, idempotencyrecord.FieldStatus, idempotencyrecord.FieldResponseBody, idempotencyrecord.FieldErrorReason: + values[i] = new(sql.NullString) + case idempotencyrecord.FieldCreatedAt, idempotencyrecord.FieldUpdatedAt, idempotencyrecord.FieldLockedUntil, idempotencyrecord.FieldExpiresAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the IdempotencyRecord fields. 
+func (_m *IdempotencyRecord) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case idempotencyrecord.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case idempotencyrecord.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case idempotencyrecord.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case idempotencyrecord.FieldScope: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field scope", values[i]) + } else if value.Valid { + _m.Scope = value.String + } + case idempotencyrecord.FieldIdempotencyKeyHash: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field idempotency_key_hash", values[i]) + } else if value.Valid { + _m.IdempotencyKeyHash = value.String + } + case idempotencyrecord.FieldRequestFingerprint: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field request_fingerprint", values[i]) + } else if value.Valid { + _m.RequestFingerprint = value.String + } + case idempotencyrecord.FieldStatus: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field status", values[i]) + } else if value.Valid { + _m.Status = value.String + } + case idempotencyrecord.FieldResponseStatus: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field response_status", 
values[i]) + } else if value.Valid { + _m.ResponseStatus = new(int) + *_m.ResponseStatus = int(value.Int64) + } + case idempotencyrecord.FieldResponseBody: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field response_body", values[i]) + } else if value.Valid { + _m.ResponseBody = new(string) + *_m.ResponseBody = value.String + } + case idempotencyrecord.FieldErrorReason: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field error_reason", values[i]) + } else if value.Valid { + _m.ErrorReason = new(string) + *_m.ErrorReason = value.String + } + case idempotencyrecord.FieldLockedUntil: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field locked_until", values[i]) + } else if value.Valid { + _m.LockedUntil = new(time.Time) + *_m.LockedUntil = value.Time + } + case idempotencyrecord.FieldExpiresAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field expires_at", values[i]) + } else if value.Valid { + _m.ExpiresAt = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the IdempotencyRecord. +// This includes values selected through modifiers, order, etc. +func (_m *IdempotencyRecord) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// Update returns a builder for updating this IdempotencyRecord. +// Note that you need to call IdempotencyRecord.Unwrap() before calling this method if this IdempotencyRecord +// was returned from a transaction, and the transaction was committed or rolled back. 
+func (_m *IdempotencyRecord) Update() *IdempotencyRecordUpdateOne { + return NewIdempotencyRecordClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the IdempotencyRecord entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *IdempotencyRecord) Unwrap() *IdempotencyRecord { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: IdempotencyRecord is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *IdempotencyRecord) String() string { + var builder strings.Builder + builder.WriteString("IdempotencyRecord(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("scope=") + builder.WriteString(_m.Scope) + builder.WriteString(", ") + builder.WriteString("idempotency_key_hash=") + builder.WriteString(_m.IdempotencyKeyHash) + builder.WriteString(", ") + builder.WriteString("request_fingerprint=") + builder.WriteString(_m.RequestFingerprint) + builder.WriteString(", ") + builder.WriteString("status=") + builder.WriteString(_m.Status) + builder.WriteString(", ") + if v := _m.ResponseStatus; v != nil { + builder.WriteString("response_status=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.ResponseBody; v != nil { + builder.WriteString("response_body=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.ErrorReason; v != nil { + builder.WriteString("error_reason=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.LockedUntil; v != nil { + builder.WriteString("locked_until=") + builder.WriteString(v.Format(time.ANSIC)) + } 
+ builder.WriteString(", ") + builder.WriteString("expires_at=") + builder.WriteString(_m.ExpiresAt.Format(time.ANSIC)) + builder.WriteByte(')') + return builder.String() +} + +// IdempotencyRecords is a parsable slice of IdempotencyRecord. +type IdempotencyRecords []*IdempotencyRecord diff --git a/backend/ent/idempotencyrecord/idempotencyrecord.go b/backend/ent/idempotencyrecord/idempotencyrecord.go new file mode 100644 index 00000000..d9686f60 --- /dev/null +++ b/backend/ent/idempotencyrecord/idempotencyrecord.go @@ -0,0 +1,148 @@ +// Code generated by ent, DO NOT EDIT. + +package idempotencyrecord + +import ( + "time" + + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the idempotencyrecord type in the database. + Label = "idempotency_record" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldScope holds the string denoting the scope field in the database. + FieldScope = "scope" + // FieldIdempotencyKeyHash holds the string denoting the idempotency_key_hash field in the database. + FieldIdempotencyKeyHash = "idempotency_key_hash" + // FieldRequestFingerprint holds the string denoting the request_fingerprint field in the database. + FieldRequestFingerprint = "request_fingerprint" + // FieldStatus holds the string denoting the status field in the database. + FieldStatus = "status" + // FieldResponseStatus holds the string denoting the response_status field in the database. + FieldResponseStatus = "response_status" + // FieldResponseBody holds the string denoting the response_body field in the database. + FieldResponseBody = "response_body" + // FieldErrorReason holds the string denoting the error_reason field in the database. 
+ FieldErrorReason = "error_reason" + // FieldLockedUntil holds the string denoting the locked_until field in the database. + FieldLockedUntil = "locked_until" + // FieldExpiresAt holds the string denoting the expires_at field in the database. + FieldExpiresAt = "expires_at" + // Table holds the table name of the idempotencyrecord in the database. + Table = "idempotency_records" +) + +// Columns holds all SQL columns for idempotencyrecord fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldScope, + FieldIdempotencyKeyHash, + FieldRequestFingerprint, + FieldStatus, + FieldResponseStatus, + FieldResponseBody, + FieldErrorReason, + FieldLockedUntil, + FieldExpiresAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // ScopeValidator is a validator for the "scope" field. It is called by the builders before save. + ScopeValidator func(string) error + // IdempotencyKeyHashValidator is a validator for the "idempotency_key_hash" field. It is called by the builders before save. + IdempotencyKeyHashValidator func(string) error + // RequestFingerprintValidator is a validator for the "request_fingerprint" field. It is called by the builders before save. + RequestFingerprintValidator func(string) error + // StatusValidator is a validator for the "status" field. It is called by the builders before save. 
+ StatusValidator func(string) error + // ErrorReasonValidator is a validator for the "error_reason" field. It is called by the builders before save. + ErrorReasonValidator func(string) error +) + +// OrderOption defines the ordering options for the IdempotencyRecord queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByScope orders the results by the scope field. +func ByScope(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldScope, opts...).ToFunc() +} + +// ByIdempotencyKeyHash orders the results by the idempotency_key_hash field. +func ByIdempotencyKeyHash(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIdempotencyKeyHash, opts...).ToFunc() +} + +// ByRequestFingerprint orders the results by the request_fingerprint field. +func ByRequestFingerprint(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRequestFingerprint, opts...).ToFunc() +} + +// ByStatus orders the results by the status field. +func ByStatus(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStatus, opts...).ToFunc() +} + +// ByResponseStatus orders the results by the response_status field. +func ByResponseStatus(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldResponseStatus, opts...).ToFunc() +} + +// ByResponseBody orders the results by the response_body field. 
+func ByResponseBody(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldResponseBody, opts...).ToFunc() +} + +// ByErrorReason orders the results by the error_reason field. +func ByErrorReason(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldErrorReason, opts...).ToFunc() +} + +// ByLockedUntil orders the results by the locked_until field. +func ByLockedUntil(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLockedUntil, opts...).ToFunc() +} + +// ByExpiresAt orders the results by the expires_at field. +func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldExpiresAt, opts...).ToFunc() +} diff --git a/backend/ent/idempotencyrecord/where.go b/backend/ent/idempotencyrecord/where.go new file mode 100644 index 00000000..c3d8d9d5 --- /dev/null +++ b/backend/ent/idempotencyrecord/where.go @@ -0,0 +1,755 @@ +// Code generated by ent, DO NOT EDIT. + +package idempotencyrecord + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. 
+func IDNotIn(ids ...int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// Scope applies equality check predicate on the "scope" field. It's identical to ScopeEQ. +func Scope(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldScope, v)) +} + +// IdempotencyKeyHash applies equality check predicate on the "idempotency_key_hash" field. It's identical to IdempotencyKeyHashEQ. +func IdempotencyKeyHash(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldIdempotencyKeyHash, v)) +} + +// RequestFingerprint applies equality check predicate on the "request_fingerprint" field. It's identical to RequestFingerprintEQ. 
+func RequestFingerprint(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldRequestFingerprint, v)) +} + +// Status applies equality check predicate on the "status" field. It's identical to StatusEQ. +func Status(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldStatus, v)) +} + +// ResponseStatus applies equality check predicate on the "response_status" field. It's identical to ResponseStatusEQ. +func ResponseStatus(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldResponseStatus, v)) +} + +// ResponseBody applies equality check predicate on the "response_body" field. It's identical to ResponseBodyEQ. +func ResponseBody(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldResponseBody, v)) +} + +// ErrorReason applies equality check predicate on the "error_reason" field. It's identical to ErrorReasonEQ. +func ErrorReason(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldErrorReason, v)) +} + +// LockedUntil applies equality check predicate on the "locked_until" field. It's identical to LockedUntilEQ. +func LockedUntil(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldLockedUntil, v)) +} + +// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ. +func ExpiresAt(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldExpiresAt, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. 
+func CreatedAtNEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. 
+func UpdatedAtIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// ScopeEQ applies the EQ predicate on the "scope" field. +func ScopeEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldScope, v)) +} + +// ScopeNEQ applies the NEQ predicate on the "scope" field. +func ScopeNEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldScope, v)) +} + +// ScopeIn applies the In predicate on the "scope" field. +func ScopeIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldScope, vs...)) +} + +// ScopeNotIn applies the NotIn predicate on the "scope" field. 
+func ScopeNotIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldScope, vs...)) +} + +// ScopeGT applies the GT predicate on the "scope" field. +func ScopeGT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldScope, v)) +} + +// ScopeGTE applies the GTE predicate on the "scope" field. +func ScopeGTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldScope, v)) +} + +// ScopeLT applies the LT predicate on the "scope" field. +func ScopeLT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldScope, v)) +} + +// ScopeLTE applies the LTE predicate on the "scope" field. +func ScopeLTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldScope, v)) +} + +// ScopeContains applies the Contains predicate on the "scope" field. +func ScopeContains(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContains(FieldScope, v)) +} + +// ScopeHasPrefix applies the HasPrefix predicate on the "scope" field. +func ScopeHasPrefix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldScope, v)) +} + +// ScopeHasSuffix applies the HasSuffix predicate on the "scope" field. +func ScopeHasSuffix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldScope, v)) +} + +// ScopeEqualFold applies the EqualFold predicate on the "scope" field. +func ScopeEqualFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldScope, v)) +} + +// ScopeContainsFold applies the ContainsFold predicate on the "scope" field. 
+func ScopeContainsFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldScope, v)) +} + +// IdempotencyKeyHashEQ applies the EQ predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashNEQ applies the NEQ predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashNEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashIn applies the In predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldIdempotencyKeyHash, vs...)) +} + +// IdempotencyKeyHashNotIn applies the NotIn predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashNotIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldIdempotencyKeyHash, vs...)) +} + +// IdempotencyKeyHashGT applies the GT predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashGT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashGTE applies the GTE predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashGTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashLT applies the LT predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashLT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashLTE applies the LTE predicate on the "idempotency_key_hash" field. 
+func IdempotencyKeyHashLTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashContains applies the Contains predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashContains(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContains(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashHasPrefix applies the HasPrefix predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashHasPrefix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashHasSuffix applies the HasSuffix predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashHasSuffix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashEqualFold applies the EqualFold predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashEqualFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashContainsFold applies the ContainsFold predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashContainsFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldIdempotencyKeyHash, v)) +} + +// RequestFingerprintEQ applies the EQ predicate on the "request_fingerprint" field. +func RequestFingerprintEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldRequestFingerprint, v)) +} + +// RequestFingerprintNEQ applies the NEQ predicate on the "request_fingerprint" field. 
+func RequestFingerprintNEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldRequestFingerprint, v)) +} + +// RequestFingerprintIn applies the In predicate on the "request_fingerprint" field. +func RequestFingerprintIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldRequestFingerprint, vs...)) +} + +// RequestFingerprintNotIn applies the NotIn predicate on the "request_fingerprint" field. +func RequestFingerprintNotIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldRequestFingerprint, vs...)) +} + +// RequestFingerprintGT applies the GT predicate on the "request_fingerprint" field. +func RequestFingerprintGT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldRequestFingerprint, v)) +} + +// RequestFingerprintGTE applies the GTE predicate on the "request_fingerprint" field. +func RequestFingerprintGTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldRequestFingerprint, v)) +} + +// RequestFingerprintLT applies the LT predicate on the "request_fingerprint" field. +func RequestFingerprintLT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldRequestFingerprint, v)) +} + +// RequestFingerprintLTE applies the LTE predicate on the "request_fingerprint" field. +func RequestFingerprintLTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldRequestFingerprint, v)) +} + +// RequestFingerprintContains applies the Contains predicate on the "request_fingerprint" field. +func RequestFingerprintContains(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContains(FieldRequestFingerprint, v)) +} + +// RequestFingerprintHasPrefix applies the HasPrefix predicate on the "request_fingerprint" field. 
+func RequestFingerprintHasPrefix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldRequestFingerprint, v)) +} + +// RequestFingerprintHasSuffix applies the HasSuffix predicate on the "request_fingerprint" field. +func RequestFingerprintHasSuffix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldRequestFingerprint, v)) +} + +// RequestFingerprintEqualFold applies the EqualFold predicate on the "request_fingerprint" field. +func RequestFingerprintEqualFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldRequestFingerprint, v)) +} + +// RequestFingerprintContainsFold applies the ContainsFold predicate on the "request_fingerprint" field. +func RequestFingerprintContainsFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldRequestFingerprint, v)) +} + +// StatusEQ applies the EQ predicate on the "status" field. +func StatusEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldStatus, v)) +} + +// StatusNEQ applies the NEQ predicate on the "status" field. +func StatusNEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldStatus, v)) +} + +// StatusIn applies the In predicate on the "status" field. +func StatusIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldStatus, vs...)) +} + +// StatusNotIn applies the NotIn predicate on the "status" field. +func StatusNotIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldStatus, vs...)) +} + +// StatusGT applies the GT predicate on the "status" field. +func StatusGT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldStatus, v)) +} + +// StatusGTE applies the GTE predicate on the "status" field. 
+func StatusGTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldStatus, v)) +} + +// StatusLT applies the LT predicate on the "status" field. +func StatusLT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldStatus, v)) +} + +// StatusLTE applies the LTE predicate on the "status" field. +func StatusLTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldStatus, v)) +} + +// StatusContains applies the Contains predicate on the "status" field. +func StatusContains(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContains(FieldStatus, v)) +} + +// StatusHasPrefix applies the HasPrefix predicate on the "status" field. +func StatusHasPrefix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldStatus, v)) +} + +// StatusHasSuffix applies the HasSuffix predicate on the "status" field. +func StatusHasSuffix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldStatus, v)) +} + +// StatusEqualFold applies the EqualFold predicate on the "status" field. +func StatusEqualFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldStatus, v)) +} + +// StatusContainsFold applies the ContainsFold predicate on the "status" field. +func StatusContainsFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldStatus, v)) +} + +// ResponseStatusEQ applies the EQ predicate on the "response_status" field. +func ResponseStatusEQ(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldResponseStatus, v)) +} + +// ResponseStatusNEQ applies the NEQ predicate on the "response_status" field. 
+func ResponseStatusNEQ(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldResponseStatus, v)) +} + +// ResponseStatusIn applies the In predicate on the "response_status" field. +func ResponseStatusIn(vs ...int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldResponseStatus, vs...)) +} + +// ResponseStatusNotIn applies the NotIn predicate on the "response_status" field. +func ResponseStatusNotIn(vs ...int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldResponseStatus, vs...)) +} + +// ResponseStatusGT applies the GT predicate on the "response_status" field. +func ResponseStatusGT(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldResponseStatus, v)) +} + +// ResponseStatusGTE applies the GTE predicate on the "response_status" field. +func ResponseStatusGTE(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldResponseStatus, v)) +} + +// ResponseStatusLT applies the LT predicate on the "response_status" field. +func ResponseStatusLT(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldResponseStatus, v)) +} + +// ResponseStatusLTE applies the LTE predicate on the "response_status" field. +func ResponseStatusLTE(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldResponseStatus, v)) +} + +// ResponseStatusIsNil applies the IsNil predicate on the "response_status" field. +func ResponseStatusIsNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIsNull(FieldResponseStatus)) +} + +// ResponseStatusNotNil applies the NotNil predicate on the "response_status" field. 
+func ResponseStatusNotNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotNull(FieldResponseStatus)) +} + +// ResponseBodyEQ applies the EQ predicate on the "response_body" field. +func ResponseBodyEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldResponseBody, v)) +} + +// ResponseBodyNEQ applies the NEQ predicate on the "response_body" field. +func ResponseBodyNEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldResponseBody, v)) +} + +// ResponseBodyIn applies the In predicate on the "response_body" field. +func ResponseBodyIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldResponseBody, vs...)) +} + +// ResponseBodyNotIn applies the NotIn predicate on the "response_body" field. +func ResponseBodyNotIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldResponseBody, vs...)) +} + +// ResponseBodyGT applies the GT predicate on the "response_body" field. +func ResponseBodyGT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldResponseBody, v)) +} + +// ResponseBodyGTE applies the GTE predicate on the "response_body" field. +func ResponseBodyGTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldResponseBody, v)) +} + +// ResponseBodyLT applies the LT predicate on the "response_body" field. +func ResponseBodyLT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldResponseBody, v)) +} + +// ResponseBodyLTE applies the LTE predicate on the "response_body" field. +func ResponseBodyLTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldResponseBody, v)) +} + +// ResponseBodyContains applies the Contains predicate on the "response_body" field. 
+func ResponseBodyContains(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContains(FieldResponseBody, v)) +} + +// ResponseBodyHasPrefix applies the HasPrefix predicate on the "response_body" field. +func ResponseBodyHasPrefix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldResponseBody, v)) +} + +// ResponseBodyHasSuffix applies the HasSuffix predicate on the "response_body" field. +func ResponseBodyHasSuffix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldResponseBody, v)) +} + +// ResponseBodyIsNil applies the IsNil predicate on the "response_body" field. +func ResponseBodyIsNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIsNull(FieldResponseBody)) +} + +// ResponseBodyNotNil applies the NotNil predicate on the "response_body" field. +func ResponseBodyNotNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotNull(FieldResponseBody)) +} + +// ResponseBodyEqualFold applies the EqualFold predicate on the "response_body" field. +func ResponseBodyEqualFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldResponseBody, v)) +} + +// ResponseBodyContainsFold applies the ContainsFold predicate on the "response_body" field. +func ResponseBodyContainsFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldResponseBody, v)) +} + +// ErrorReasonEQ applies the EQ predicate on the "error_reason" field. +func ErrorReasonEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldErrorReason, v)) +} + +// ErrorReasonNEQ applies the NEQ predicate on the "error_reason" field. 
+func ErrorReasonNEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldErrorReason, v)) +} + +// ErrorReasonIn applies the In predicate on the "error_reason" field. +func ErrorReasonIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldErrorReason, vs...)) +} + +// ErrorReasonNotIn applies the NotIn predicate on the "error_reason" field. +func ErrorReasonNotIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldErrorReason, vs...)) +} + +// ErrorReasonGT applies the GT predicate on the "error_reason" field. +func ErrorReasonGT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldErrorReason, v)) +} + +// ErrorReasonGTE applies the GTE predicate on the "error_reason" field. +func ErrorReasonGTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldErrorReason, v)) +} + +// ErrorReasonLT applies the LT predicate on the "error_reason" field. +func ErrorReasonLT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldErrorReason, v)) +} + +// ErrorReasonLTE applies the LTE predicate on the "error_reason" field. +func ErrorReasonLTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldErrorReason, v)) +} + +// ErrorReasonContains applies the Contains predicate on the "error_reason" field. +func ErrorReasonContains(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContains(FieldErrorReason, v)) +} + +// ErrorReasonHasPrefix applies the HasPrefix predicate on the "error_reason" field. +func ErrorReasonHasPrefix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldErrorReason, v)) +} + +// ErrorReasonHasSuffix applies the HasSuffix predicate on the "error_reason" field. 
+func ErrorReasonHasSuffix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldErrorReason, v)) +} + +// ErrorReasonIsNil applies the IsNil predicate on the "error_reason" field. +func ErrorReasonIsNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIsNull(FieldErrorReason)) +} + +// ErrorReasonNotNil applies the NotNil predicate on the "error_reason" field. +func ErrorReasonNotNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotNull(FieldErrorReason)) +} + +// ErrorReasonEqualFold applies the EqualFold predicate on the "error_reason" field. +func ErrorReasonEqualFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldErrorReason, v)) +} + +// ErrorReasonContainsFold applies the ContainsFold predicate on the "error_reason" field. +func ErrorReasonContainsFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldErrorReason, v)) +} + +// LockedUntilEQ applies the EQ predicate on the "locked_until" field. +func LockedUntilEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldLockedUntil, v)) +} + +// LockedUntilNEQ applies the NEQ predicate on the "locked_until" field. +func LockedUntilNEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldLockedUntil, v)) +} + +// LockedUntilIn applies the In predicate on the "locked_until" field. +func LockedUntilIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldLockedUntil, vs...)) +} + +// LockedUntilNotIn applies the NotIn predicate on the "locked_until" field. 
+func LockedUntilNotIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldLockedUntil, vs...)) +} + +// LockedUntilGT applies the GT predicate on the "locked_until" field. +func LockedUntilGT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldLockedUntil, v)) +} + +// LockedUntilGTE applies the GTE predicate on the "locked_until" field. +func LockedUntilGTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldLockedUntil, v)) +} + +// LockedUntilLT applies the LT predicate on the "locked_until" field. +func LockedUntilLT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldLockedUntil, v)) +} + +// LockedUntilLTE applies the LTE predicate on the "locked_until" field. +func LockedUntilLTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldLockedUntil, v)) +} + +// LockedUntilIsNil applies the IsNil predicate on the "locked_until" field. +func LockedUntilIsNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIsNull(FieldLockedUntil)) +} + +// LockedUntilNotNil applies the NotNil predicate on the "locked_until" field. +func LockedUntilNotNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotNull(FieldLockedUntil)) +} + +// ExpiresAtEQ applies the EQ predicate on the "expires_at" field. +func ExpiresAtEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldExpiresAt, v)) +} + +// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field. +func ExpiresAtNEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldExpiresAt, v)) +} + +// ExpiresAtIn applies the In predicate on the "expires_at" field. 
+func ExpiresAtIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field. +func ExpiresAtNotIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtGT applies the GT predicate on the "expires_at" field. +func ExpiresAtGT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldExpiresAt, v)) +} + +// ExpiresAtGTE applies the GTE predicate on the "expires_at" field. +func ExpiresAtGTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldExpiresAt, v)) +} + +// ExpiresAtLT applies the LT predicate on the "expires_at" field. +func ExpiresAtLT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldExpiresAt, v)) +} + +// ExpiresAtLTE applies the LTE predicate on the "expires_at" field. +func ExpiresAtLTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldExpiresAt, v)) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.IdempotencyRecord) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.IdempotencyRecord) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. 
+func Not(p predicate.IdempotencyRecord) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.NotPredicates(p)) +} diff --git a/backend/ent/idempotencyrecord_create.go b/backend/ent/idempotencyrecord_create.go new file mode 100644 index 00000000..bf4deaf2 --- /dev/null +++ b/backend/ent/idempotencyrecord_create.go @@ -0,0 +1,1132 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" +) + +// IdempotencyRecordCreate is the builder for creating a IdempotencyRecord entity. +type IdempotencyRecordCreate struct { + config + mutation *IdempotencyRecordMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *IdempotencyRecordCreate) SetCreatedAt(v time.Time) *IdempotencyRecordCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *IdempotencyRecordCreate) SetNillableCreatedAt(v *time.Time) *IdempotencyRecordCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *IdempotencyRecordCreate) SetUpdatedAt(v time.Time) *IdempotencyRecordCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *IdempotencyRecordCreate) SetNillableUpdatedAt(v *time.Time) *IdempotencyRecordCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetScope sets the "scope" field. +func (_c *IdempotencyRecordCreate) SetScope(v string) *IdempotencyRecordCreate { + _c.mutation.SetScope(v) + return _c +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. 
+func (_c *IdempotencyRecordCreate) SetIdempotencyKeyHash(v string) *IdempotencyRecordCreate { + _c.mutation.SetIdempotencyKeyHash(v) + return _c +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (_c *IdempotencyRecordCreate) SetRequestFingerprint(v string) *IdempotencyRecordCreate { + _c.mutation.SetRequestFingerprint(v) + return _c +} + +// SetStatus sets the "status" field. +func (_c *IdempotencyRecordCreate) SetStatus(v string) *IdempotencyRecordCreate { + _c.mutation.SetStatus(v) + return _c +} + +// SetResponseStatus sets the "response_status" field. +func (_c *IdempotencyRecordCreate) SetResponseStatus(v int) *IdempotencyRecordCreate { + _c.mutation.SetResponseStatus(v) + return _c +} + +// SetNillableResponseStatus sets the "response_status" field if the given value is not nil. +func (_c *IdempotencyRecordCreate) SetNillableResponseStatus(v *int) *IdempotencyRecordCreate { + if v != nil { + _c.SetResponseStatus(*v) + } + return _c +} + +// SetResponseBody sets the "response_body" field. +func (_c *IdempotencyRecordCreate) SetResponseBody(v string) *IdempotencyRecordCreate { + _c.mutation.SetResponseBody(v) + return _c +} + +// SetNillableResponseBody sets the "response_body" field if the given value is not nil. +func (_c *IdempotencyRecordCreate) SetNillableResponseBody(v *string) *IdempotencyRecordCreate { + if v != nil { + _c.SetResponseBody(*v) + } + return _c +} + +// SetErrorReason sets the "error_reason" field. +func (_c *IdempotencyRecordCreate) SetErrorReason(v string) *IdempotencyRecordCreate { + _c.mutation.SetErrorReason(v) + return _c +} + +// SetNillableErrorReason sets the "error_reason" field if the given value is not nil. +func (_c *IdempotencyRecordCreate) SetNillableErrorReason(v *string) *IdempotencyRecordCreate { + if v != nil { + _c.SetErrorReason(*v) + } + return _c +} + +// SetLockedUntil sets the "locked_until" field. 
+func (_c *IdempotencyRecordCreate) SetLockedUntil(v time.Time) *IdempotencyRecordCreate { + _c.mutation.SetLockedUntil(v) + return _c +} + +// SetNillableLockedUntil sets the "locked_until" field if the given value is not nil. +func (_c *IdempotencyRecordCreate) SetNillableLockedUntil(v *time.Time) *IdempotencyRecordCreate { + if v != nil { + _c.SetLockedUntil(*v) + } + return _c +} + +// SetExpiresAt sets the "expires_at" field. +func (_c *IdempotencyRecordCreate) SetExpiresAt(v time.Time) *IdempotencyRecordCreate { + _c.mutation.SetExpiresAt(v) + return _c +} + +// Mutation returns the IdempotencyRecordMutation object of the builder. +func (_c *IdempotencyRecordCreate) Mutation() *IdempotencyRecordMutation { + return _c.mutation +} + +// Save creates the IdempotencyRecord in the database. +func (_c *IdempotencyRecordCreate) Save(ctx context.Context) (*IdempotencyRecord, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *IdempotencyRecordCreate) SaveX(ctx context.Context) *IdempotencyRecord { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *IdempotencyRecordCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *IdempotencyRecordCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *IdempotencyRecordCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := idempotencyrecord.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := idempotencyrecord.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_c *IdempotencyRecordCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "IdempotencyRecord.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "IdempotencyRecord.updated_at"`)} + } + if _, ok := _c.mutation.Scope(); !ok { + return &ValidationError{Name: "scope", err: errors.New(`ent: missing required field "IdempotencyRecord.scope"`)} + } + if v, ok := _c.mutation.Scope(); ok { + if err := idempotencyrecord.ScopeValidator(v); err != nil { + return &ValidationError{Name: "scope", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.scope": %w`, err)} + } + } + if _, ok := _c.mutation.IdempotencyKeyHash(); !ok { + return &ValidationError{Name: "idempotency_key_hash", err: errors.New(`ent: missing required field "IdempotencyRecord.idempotency_key_hash"`)} + } + if v, ok := _c.mutation.IdempotencyKeyHash(); ok { + if err := idempotencyrecord.IdempotencyKeyHashValidator(v); err != nil { + return &ValidationError{Name: "idempotency_key_hash", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.idempotency_key_hash": %w`, err)} + } + } + if _, ok := _c.mutation.RequestFingerprint(); !ok { + return &ValidationError{Name: "request_fingerprint", err: errors.New(`ent: missing required field "IdempotencyRecord.request_fingerprint"`)} + } + if v, ok := _c.mutation.RequestFingerprint(); ok { + if err := idempotencyrecord.RequestFingerprintValidator(v); err != nil { + return &ValidationError{Name: "request_fingerprint", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.request_fingerprint": %w`, err)} + } + } + if _, ok := _c.mutation.Status(); !ok { + return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "IdempotencyRecord.status"`)} + } + if v, ok := _c.mutation.Status(); ok { + if 
err := idempotencyrecord.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.status": %w`, err)} + } + } + if v, ok := _c.mutation.ErrorReason(); ok { + if err := idempotencyrecord.ErrorReasonValidator(v); err != nil { + return &ValidationError{Name: "error_reason", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.error_reason": %w`, err)} + } + } + if _, ok := _c.mutation.ExpiresAt(); !ok { + return &ValidationError{Name: "expires_at", err: errors.New(`ent: missing required field "IdempotencyRecord.expires_at"`)} + } + return nil +} + +func (_c *IdempotencyRecordCreate) sqlSave(ctx context.Context) (*IdempotencyRecord, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *IdempotencyRecordCreate) createSpec() (*IdempotencyRecord, *sqlgraph.CreateSpec) { + var ( + _node = &IdempotencyRecord{config: _c.config} + _spec = sqlgraph.NewCreateSpec(idempotencyrecord.Table, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(idempotencyrecord.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(idempotencyrecord.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.Scope(); ok { + _spec.SetField(idempotencyrecord.FieldScope, field.TypeString, value) + _node.Scope = value + } + if value, ok := _c.mutation.IdempotencyKeyHash(); ok { + 
_spec.SetField(idempotencyrecord.FieldIdempotencyKeyHash, field.TypeString, value) + _node.IdempotencyKeyHash = value + } + if value, ok := _c.mutation.RequestFingerprint(); ok { + _spec.SetField(idempotencyrecord.FieldRequestFingerprint, field.TypeString, value) + _node.RequestFingerprint = value + } + if value, ok := _c.mutation.Status(); ok { + _spec.SetField(idempotencyrecord.FieldStatus, field.TypeString, value) + _node.Status = value + } + if value, ok := _c.mutation.ResponseStatus(); ok { + _spec.SetField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value) + _node.ResponseStatus = &value + } + if value, ok := _c.mutation.ResponseBody(); ok { + _spec.SetField(idempotencyrecord.FieldResponseBody, field.TypeString, value) + _node.ResponseBody = &value + } + if value, ok := _c.mutation.ErrorReason(); ok { + _spec.SetField(idempotencyrecord.FieldErrorReason, field.TypeString, value) + _node.ErrorReason = &value + } + if value, ok := _c.mutation.LockedUntil(); ok { + _spec.SetField(idempotencyrecord.FieldLockedUntil, field.TypeTime, value) + _node.LockedUntil = &value + } + if value, ok := _c.mutation.ExpiresAt(); ok { + _spec.SetField(idempotencyrecord.FieldExpiresAt, field.TypeTime, value) + _node.ExpiresAt = value + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.IdempotencyRecord.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.IdempotencyRecordUpsert) { +// SetCreatedAt(v+v). +// }). 
+// Exec(ctx) +func (_c *IdempotencyRecordCreate) OnConflict(opts ...sql.ConflictOption) *IdempotencyRecordUpsertOne { + _c.conflict = opts + return &IdempotencyRecordUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.IdempotencyRecord.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *IdempotencyRecordCreate) OnConflictColumns(columns ...string) *IdempotencyRecordUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &IdempotencyRecordUpsertOne{ + create: _c, + } +} + +type ( + // IdempotencyRecordUpsertOne is the builder for "upsert"-ing + // one IdempotencyRecord node. + IdempotencyRecordUpsertOne struct { + create *IdempotencyRecordCreate + } + + // IdempotencyRecordUpsert is the "OnConflict" setter. + IdempotencyRecordUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *IdempotencyRecordUpsert) SetUpdatedAt(v time.Time) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateUpdatedAt() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldUpdatedAt) + return u +} + +// SetScope sets the "scope" field. +func (u *IdempotencyRecordUpsert) SetScope(v string) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldScope, v) + return u +} + +// UpdateScope sets the "scope" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateScope() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldScope) + return u +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. 
+func (u *IdempotencyRecordUpsert) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldIdempotencyKeyHash, v) + return u +} + +// UpdateIdempotencyKeyHash sets the "idempotency_key_hash" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateIdempotencyKeyHash() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldIdempotencyKeyHash) + return u +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (u *IdempotencyRecordUpsert) SetRequestFingerprint(v string) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldRequestFingerprint, v) + return u +} + +// UpdateRequestFingerprint sets the "request_fingerprint" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateRequestFingerprint() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldRequestFingerprint) + return u +} + +// SetStatus sets the "status" field. +func (u *IdempotencyRecordUpsert) SetStatus(v string) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldStatus, v) + return u +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateStatus() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldStatus) + return u +} + +// SetResponseStatus sets the "response_status" field. +func (u *IdempotencyRecordUpsert) SetResponseStatus(v int) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldResponseStatus, v) + return u +} + +// UpdateResponseStatus sets the "response_status" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateResponseStatus() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldResponseStatus) + return u +} + +// AddResponseStatus adds v to the "response_status" field. 
+func (u *IdempotencyRecordUpsert) AddResponseStatus(v int) *IdempotencyRecordUpsert { + u.Add(idempotencyrecord.FieldResponseStatus, v) + return u +} + +// ClearResponseStatus clears the value of the "response_status" field. +func (u *IdempotencyRecordUpsert) ClearResponseStatus() *IdempotencyRecordUpsert { + u.SetNull(idempotencyrecord.FieldResponseStatus) + return u +} + +// SetResponseBody sets the "response_body" field. +func (u *IdempotencyRecordUpsert) SetResponseBody(v string) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldResponseBody, v) + return u +} + +// UpdateResponseBody sets the "response_body" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateResponseBody() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldResponseBody) + return u +} + +// ClearResponseBody clears the value of the "response_body" field. +func (u *IdempotencyRecordUpsert) ClearResponseBody() *IdempotencyRecordUpsert { + u.SetNull(idempotencyrecord.FieldResponseBody) + return u +} + +// SetErrorReason sets the "error_reason" field. +func (u *IdempotencyRecordUpsert) SetErrorReason(v string) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldErrorReason, v) + return u +} + +// UpdateErrorReason sets the "error_reason" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateErrorReason() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldErrorReason) + return u +} + +// ClearErrorReason clears the value of the "error_reason" field. +func (u *IdempotencyRecordUpsert) ClearErrorReason() *IdempotencyRecordUpsert { + u.SetNull(idempotencyrecord.FieldErrorReason) + return u +} + +// SetLockedUntil sets the "locked_until" field. +func (u *IdempotencyRecordUpsert) SetLockedUntil(v time.Time) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldLockedUntil, v) + return u +} + +// UpdateLockedUntil sets the "locked_until" field to the value that was provided on create. 
+func (u *IdempotencyRecordUpsert) UpdateLockedUntil() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldLockedUntil) + return u +} + +// ClearLockedUntil clears the value of the "locked_until" field. +func (u *IdempotencyRecordUpsert) ClearLockedUntil() *IdempotencyRecordUpsert { + u.SetNull(idempotencyrecord.FieldLockedUntil) + return u +} + +// SetExpiresAt sets the "expires_at" field. +func (u *IdempotencyRecordUpsert) SetExpiresAt(v time.Time) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldExpiresAt, v) + return u +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateExpiresAt() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldExpiresAt) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.IdempotencyRecord.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *IdempotencyRecordUpsertOne) UpdateNewValues() *IdempotencyRecordUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(idempotencyrecord.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.IdempotencyRecord.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *IdempotencyRecordUpsertOne) Ignore() *IdempotencyRecordUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. 
+func (u *IdempotencyRecordUpsertOne) DoNothing() *IdempotencyRecordUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the IdempotencyRecordCreate.OnConflict +// documentation for more info. +func (u *IdempotencyRecordUpsertOne) Update(set func(*IdempotencyRecordUpsert)) *IdempotencyRecordUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&IdempotencyRecordUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *IdempotencyRecordUpsertOne) SetUpdatedAt(v time.Time) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateUpdatedAt() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetScope sets the "scope" field. +func (u *IdempotencyRecordUpsertOne) SetScope(v string) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetScope(v) + }) +} + +// UpdateScope sets the "scope" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateScope() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateScope() + }) +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. +func (u *IdempotencyRecordUpsertOne) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetIdempotencyKeyHash(v) + }) +} + +// UpdateIdempotencyKeyHash sets the "idempotency_key_hash" field to the value that was provided on create. 
+func (u *IdempotencyRecordUpsertOne) UpdateIdempotencyKeyHash() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateIdempotencyKeyHash() + }) +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (u *IdempotencyRecordUpsertOne) SetRequestFingerprint(v string) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetRequestFingerprint(v) + }) +} + +// UpdateRequestFingerprint sets the "request_fingerprint" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateRequestFingerprint() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateRequestFingerprint() + }) +} + +// SetStatus sets the "status" field. +func (u *IdempotencyRecordUpsertOne) SetStatus(v string) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateStatus() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateStatus() + }) +} + +// SetResponseStatus sets the "response_status" field. +func (u *IdempotencyRecordUpsertOne) SetResponseStatus(v int) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetResponseStatus(v) + }) +} + +// AddResponseStatus adds v to the "response_status" field. +func (u *IdempotencyRecordUpsertOne) AddResponseStatus(v int) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.AddResponseStatus(v) + }) +} + +// UpdateResponseStatus sets the "response_status" field to the value that was provided on create. 
+func (u *IdempotencyRecordUpsertOne) UpdateResponseStatus() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateResponseStatus() + }) +} + +// ClearResponseStatus clears the value of the "response_status" field. +func (u *IdempotencyRecordUpsertOne) ClearResponseStatus() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearResponseStatus() + }) +} + +// SetResponseBody sets the "response_body" field. +func (u *IdempotencyRecordUpsertOne) SetResponseBody(v string) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetResponseBody(v) + }) +} + +// UpdateResponseBody sets the "response_body" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateResponseBody() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateResponseBody() + }) +} + +// ClearResponseBody clears the value of the "response_body" field. +func (u *IdempotencyRecordUpsertOne) ClearResponseBody() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearResponseBody() + }) +} + +// SetErrorReason sets the "error_reason" field. +func (u *IdempotencyRecordUpsertOne) SetErrorReason(v string) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetErrorReason(v) + }) +} + +// UpdateErrorReason sets the "error_reason" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateErrorReason() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateErrorReason() + }) +} + +// ClearErrorReason clears the value of the "error_reason" field. +func (u *IdempotencyRecordUpsertOne) ClearErrorReason() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearErrorReason() + }) +} + +// SetLockedUntil sets the "locked_until" field. 
+func (u *IdempotencyRecordUpsertOne) SetLockedUntil(v time.Time) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetLockedUntil(v) + }) +} + +// UpdateLockedUntil sets the "locked_until" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateLockedUntil() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateLockedUntil() + }) +} + +// ClearLockedUntil clears the value of the "locked_until" field. +func (u *IdempotencyRecordUpsertOne) ClearLockedUntil() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearLockedUntil() + }) +} + +// SetExpiresAt sets the "expires_at" field. +func (u *IdempotencyRecordUpsertOne) SetExpiresAt(v time.Time) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateExpiresAt() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateExpiresAt() + }) +} + +// Exec executes the query. +func (u *IdempotencyRecordUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for IdempotencyRecordCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *IdempotencyRecordUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *IdempotencyRecordUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. 
+func (u *IdempotencyRecordUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// IdempotencyRecordCreateBulk is the builder for creating many IdempotencyRecord entities in bulk. +type IdempotencyRecordCreateBulk struct { + config + err error + builders []*IdempotencyRecordCreate + conflict []sql.ConflictOption +} + +// Save creates the IdempotencyRecord entities in the database. +func (_c *IdempotencyRecordCreateBulk) Save(ctx context.Context) ([]*IdempotencyRecord, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*IdempotencyRecord, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*IdempotencyRecordMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. 
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *IdempotencyRecordCreateBulk) SaveX(ctx context.Context) []*IdempotencyRecord { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *IdempotencyRecordCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *IdempotencyRecordCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.IdempotencyRecord.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.IdempotencyRecordUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *IdempotencyRecordCreateBulk) OnConflict(opts ...sql.ConflictOption) *IdempotencyRecordUpsertBulk { + _c.conflict = opts + return &IdempotencyRecordUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. 
Using this option is equivalent to using: +// +// client.IdempotencyRecord.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *IdempotencyRecordCreateBulk) OnConflictColumns(columns ...string) *IdempotencyRecordUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &IdempotencyRecordUpsertBulk{ + create: _c, + } +} + +// IdempotencyRecordUpsertBulk is the builder for "upsert"-ing +// a bulk of IdempotencyRecord nodes. +type IdempotencyRecordUpsertBulk struct { + create *IdempotencyRecordCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.IdempotencyRecord.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *IdempotencyRecordUpsertBulk) UpdateNewValues() *IdempotencyRecordUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(idempotencyrecord.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.IdempotencyRecord.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *IdempotencyRecordUpsertBulk) Ignore() *IdempotencyRecordUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *IdempotencyRecordUpsertBulk) DoNothing() *IdempotencyRecordUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. 
See the IdempotencyRecordCreateBulk.OnConflict +// documentation for more info. +func (u *IdempotencyRecordUpsertBulk) Update(set func(*IdempotencyRecordUpsert)) *IdempotencyRecordUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&IdempotencyRecordUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *IdempotencyRecordUpsertBulk) SetUpdatedAt(v time.Time) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateUpdatedAt() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetScope sets the "scope" field. +func (u *IdempotencyRecordUpsertBulk) SetScope(v string) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetScope(v) + }) +} + +// UpdateScope sets the "scope" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateScope() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateScope() + }) +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. +func (u *IdempotencyRecordUpsertBulk) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetIdempotencyKeyHash(v) + }) +} + +// UpdateIdempotencyKeyHash sets the "idempotency_key_hash" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateIdempotencyKeyHash() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateIdempotencyKeyHash() + }) +} + +// SetRequestFingerprint sets the "request_fingerprint" field. 
+func (u *IdempotencyRecordUpsertBulk) SetRequestFingerprint(v string) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetRequestFingerprint(v) + }) +} + +// UpdateRequestFingerprint sets the "request_fingerprint" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateRequestFingerprint() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateRequestFingerprint() + }) +} + +// SetStatus sets the "status" field. +func (u *IdempotencyRecordUpsertBulk) SetStatus(v string) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateStatus() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateStatus() + }) +} + +// SetResponseStatus sets the "response_status" field. +func (u *IdempotencyRecordUpsertBulk) SetResponseStatus(v int) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetResponseStatus(v) + }) +} + +// AddResponseStatus adds v to the "response_status" field. +func (u *IdempotencyRecordUpsertBulk) AddResponseStatus(v int) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.AddResponseStatus(v) + }) +} + +// UpdateResponseStatus sets the "response_status" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateResponseStatus() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateResponseStatus() + }) +} + +// ClearResponseStatus clears the value of the "response_status" field. 
+func (u *IdempotencyRecordUpsertBulk) ClearResponseStatus() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearResponseStatus() + }) +} + +// SetResponseBody sets the "response_body" field. +func (u *IdempotencyRecordUpsertBulk) SetResponseBody(v string) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetResponseBody(v) + }) +} + +// UpdateResponseBody sets the "response_body" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateResponseBody() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateResponseBody() + }) +} + +// ClearResponseBody clears the value of the "response_body" field. +func (u *IdempotencyRecordUpsertBulk) ClearResponseBody() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearResponseBody() + }) +} + +// SetErrorReason sets the "error_reason" field. +func (u *IdempotencyRecordUpsertBulk) SetErrorReason(v string) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetErrorReason(v) + }) +} + +// UpdateErrorReason sets the "error_reason" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateErrorReason() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateErrorReason() + }) +} + +// ClearErrorReason clears the value of the "error_reason" field. +func (u *IdempotencyRecordUpsertBulk) ClearErrorReason() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearErrorReason() + }) +} + +// SetLockedUntil sets the "locked_until" field. 
+func (u *IdempotencyRecordUpsertBulk) SetLockedUntil(v time.Time) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetLockedUntil(v) + }) +} + +// UpdateLockedUntil sets the "locked_until" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateLockedUntil() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateLockedUntil() + }) +} + +// ClearLockedUntil clears the value of the "locked_until" field. +func (u *IdempotencyRecordUpsertBulk) ClearLockedUntil() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearLockedUntil() + }) +} + +// SetExpiresAt sets the "expires_at" field. +func (u *IdempotencyRecordUpsertBulk) SetExpiresAt(v time.Time) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateExpiresAt() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateExpiresAt() + }) +} + +// Exec executes the query. +func (u *IdempotencyRecordUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the IdempotencyRecordCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for IdempotencyRecordCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (u *IdempotencyRecordUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/idempotencyrecord_delete.go b/backend/ent/idempotencyrecord_delete.go new file mode 100644 index 00000000..f5c87559 --- /dev/null +++ b/backend/ent/idempotencyrecord_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdempotencyRecordDelete is the builder for deleting a IdempotencyRecord entity. +type IdempotencyRecordDelete struct { + config + hooks []Hook + mutation *IdempotencyRecordMutation +} + +// Where appends a list predicates to the IdempotencyRecordDelete builder. +func (_d *IdempotencyRecordDelete) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *IdempotencyRecordDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (_d *IdempotencyRecordDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *IdempotencyRecordDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(idempotencyrecord.Table, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// IdempotencyRecordDeleteOne is the builder for deleting a single IdempotencyRecord entity. +type IdempotencyRecordDeleteOne struct { + _d *IdempotencyRecordDelete +} + +// Where appends a list predicates to the IdempotencyRecordDelete builder. +func (_d *IdempotencyRecordDeleteOne) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *IdempotencyRecordDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{idempotencyrecord.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *IdempotencyRecordDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/idempotencyrecord_query.go b/backend/ent/idempotencyrecord_query.go new file mode 100644 index 00000000..fbba4dfa --- /dev/null +++ b/backend/ent/idempotencyrecord_query.go @@ -0,0 +1,564 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdempotencyRecordQuery is the builder for querying IdempotencyRecord entities. +type IdempotencyRecordQuery struct { + config + ctx *QueryContext + order []idempotencyrecord.OrderOption + inters []Interceptor + predicates []predicate.IdempotencyRecord + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the IdempotencyRecordQuery builder. +func (_q *IdempotencyRecordQuery) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *IdempotencyRecordQuery) Limit(limit int) *IdempotencyRecordQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *IdempotencyRecordQuery) Offset(offset int) *IdempotencyRecordQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *IdempotencyRecordQuery) Unique(unique bool) *IdempotencyRecordQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *IdempotencyRecordQuery) Order(o ...idempotencyrecord.OrderOption) *IdempotencyRecordQuery { + _q.order = append(_q.order, o...) + return _q +} + +// First returns the first IdempotencyRecord entity from the query. +// Returns a *NotFoundError when no IdempotencyRecord was found. 
+func (_q *IdempotencyRecordQuery) First(ctx context.Context) (*IdempotencyRecord, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{idempotencyrecord.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) FirstX(ctx context.Context) *IdempotencyRecord { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first IdempotencyRecord ID from the query. +// Returns a *NotFoundError when no IdempotencyRecord ID was found. +func (_q *IdempotencyRecordQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{idempotencyrecord.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single IdempotencyRecord entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one IdempotencyRecord entity is found. +// Returns a *NotFoundError when no IdempotencyRecord entities are found. +func (_q *IdempotencyRecordQuery) Only(ctx context.Context) (*IdempotencyRecord, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{idempotencyrecord.Label} + default: + return nil, &NotSingularError{idempotencyrecord.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. 
+func (_q *IdempotencyRecordQuery) OnlyX(ctx context.Context) *IdempotencyRecord { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only IdempotencyRecord ID in the query. +// Returns a *NotSingularError when more than one IdempotencyRecord ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *IdempotencyRecordQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{idempotencyrecord.Label} + default: + err = &NotSingularError{idempotencyrecord.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of IdempotencyRecords. +func (_q *IdempotencyRecordQuery) All(ctx context.Context) ([]*IdempotencyRecord, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*IdempotencyRecord, *IdempotencyRecordQuery]() + return withInterceptors[[]*IdempotencyRecord](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) AllX(ctx context.Context) []*IdempotencyRecord { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of IdempotencyRecord IDs. 
+func (_q *IdempotencyRecordQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(idempotencyrecord.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *IdempotencyRecordQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*IdempotencyRecordQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *IdempotencyRecordQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the IdempotencyRecordQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. 
+func (_q *IdempotencyRecordQuery) Clone() *IdempotencyRecordQuery { + if _q == nil { + return nil + } + return &IdempotencyRecordQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]idempotencyrecord.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.IdempotencyRecord{}, _q.predicates...), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.IdempotencyRecord.Query(). +// GroupBy(idempotencyrecord.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *IdempotencyRecordQuery) GroupBy(field string, fields ...string) *IdempotencyRecordGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &IdempotencyRecordGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = idempotencyrecord.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.IdempotencyRecord.Query(). +// Select(idempotencyrecord.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *IdempotencyRecordQuery) Select(fields ...string) *IdempotencyRecordSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &IdempotencyRecordSelect{IdempotencyRecordQuery: _q} + sbuild.label = idempotencyrecord.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a IdempotencyRecordSelect configured with the given aggregations. 
+func (_q *IdempotencyRecordQuery) Aggregate(fns ...AggregateFunc) *IdempotencyRecordSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *IdempotencyRecordQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !idempotencyrecord.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *IdempotencyRecordQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*IdempotencyRecord, error) { + var ( + nodes = []*IdempotencyRecord{} + _spec = _q.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*IdempotencyRecord).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &IdempotencyRecord{config: _q.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + return nodes, nil +} + +func (_q *IdempotencyRecordQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *IdempotencyRecordQuery) querySpec() *sqlgraph.QuerySpec { + _spec := 
sqlgraph.NewQuerySpec(idempotencyrecord.Table, idempotencyrecord.Columns, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, idempotencyrecord.FieldID) + for i := range fields { + if fields[i] != idempotencyrecord.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *IdempotencyRecordQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(idempotencyrecord.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = idempotencyrecord.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. 
+ selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *IdempotencyRecordQuery) ForUpdate(opts ...sql.LockOption) *IdempotencyRecordQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *IdempotencyRecordQuery) ForShare(opts ...sql.LockOption) *IdempotencyRecordQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// IdempotencyRecordGroupBy is the group-by builder for IdempotencyRecord entities. +type IdempotencyRecordGroupBy struct { + selector + build *IdempotencyRecordQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *IdempotencyRecordGroupBy) Aggregate(fns ...AggregateFunc) *IdempotencyRecordGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. 
+func (_g *IdempotencyRecordGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*IdempotencyRecordQuery, *IdempotencyRecordGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *IdempotencyRecordGroupBy) sqlScan(ctx context.Context, root *IdempotencyRecordQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// IdempotencyRecordSelect is the builder for selecting fields of IdempotencyRecord entities. +type IdempotencyRecordSelect struct { + *IdempotencyRecordQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *IdempotencyRecordSelect) Aggregate(fns ...AggregateFunc) *IdempotencyRecordSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. 
+func (_s *IdempotencyRecordSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*IdempotencyRecordQuery, *IdempotencyRecordSelect](ctx, _s.IdempotencyRecordQuery, _s, _s.inters, v) +} + +func (_s *IdempotencyRecordSelect) sqlScan(ctx context.Context, root *IdempotencyRecordQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/idempotencyrecord_update.go b/backend/ent/idempotencyrecord_update.go new file mode 100644 index 00000000..f839e5c0 --- /dev/null +++ b/backend/ent/idempotencyrecord_update.go @@ -0,0 +1,676 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdempotencyRecordUpdate is the builder for updating IdempotencyRecord entities. +type IdempotencyRecordUpdate struct { + config + hooks []Hook + mutation *IdempotencyRecordMutation +} + +// Where appends a list predicates to the IdempotencyRecordUpdate builder. +func (_u *IdempotencyRecordUpdate) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. 
+func (_u *IdempotencyRecordUpdate) SetUpdatedAt(v time.Time) *IdempotencyRecordUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetScope sets the "scope" field. +func (_u *IdempotencyRecordUpdate) SetScope(v string) *IdempotencyRecordUpdate { + _u.mutation.SetScope(v) + return _u +} + +// SetNillableScope sets the "scope" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableScope(v *string) *IdempotencyRecordUpdate { + if v != nil { + _u.SetScope(*v) + } + return _u +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. +func (_u *IdempotencyRecordUpdate) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpdate { + _u.mutation.SetIdempotencyKeyHash(v) + return _u +} + +// SetNillableIdempotencyKeyHash sets the "idempotency_key_hash" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableIdempotencyKeyHash(v *string) *IdempotencyRecordUpdate { + if v != nil { + _u.SetIdempotencyKeyHash(*v) + } + return _u +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (_u *IdempotencyRecordUpdate) SetRequestFingerprint(v string) *IdempotencyRecordUpdate { + _u.mutation.SetRequestFingerprint(v) + return _u +} + +// SetNillableRequestFingerprint sets the "request_fingerprint" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableRequestFingerprint(v *string) *IdempotencyRecordUpdate { + if v != nil { + _u.SetRequestFingerprint(*v) + } + return _u +} + +// SetStatus sets the "status" field. +func (_u *IdempotencyRecordUpdate) SetStatus(v string) *IdempotencyRecordUpdate { + _u.mutation.SetStatus(v) + return _u +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableStatus(v *string) *IdempotencyRecordUpdate { + if v != nil { + _u.SetStatus(*v) + } + return _u +} + +// SetResponseStatus sets the "response_status" field. 
+func (_u *IdempotencyRecordUpdate) SetResponseStatus(v int) *IdempotencyRecordUpdate { + _u.mutation.ResetResponseStatus() + _u.mutation.SetResponseStatus(v) + return _u +} + +// SetNillableResponseStatus sets the "response_status" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableResponseStatus(v *int) *IdempotencyRecordUpdate { + if v != nil { + _u.SetResponseStatus(*v) + } + return _u +} + +// AddResponseStatus adds value to the "response_status" field. +func (_u *IdempotencyRecordUpdate) AddResponseStatus(v int) *IdempotencyRecordUpdate { + _u.mutation.AddResponseStatus(v) + return _u +} + +// ClearResponseStatus clears the value of the "response_status" field. +func (_u *IdempotencyRecordUpdate) ClearResponseStatus() *IdempotencyRecordUpdate { + _u.mutation.ClearResponseStatus() + return _u +} + +// SetResponseBody sets the "response_body" field. +func (_u *IdempotencyRecordUpdate) SetResponseBody(v string) *IdempotencyRecordUpdate { + _u.mutation.SetResponseBody(v) + return _u +} + +// SetNillableResponseBody sets the "response_body" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableResponseBody(v *string) *IdempotencyRecordUpdate { + if v != nil { + _u.SetResponseBody(*v) + } + return _u +} + +// ClearResponseBody clears the value of the "response_body" field. +func (_u *IdempotencyRecordUpdate) ClearResponseBody() *IdempotencyRecordUpdate { + _u.mutation.ClearResponseBody() + return _u +} + +// SetErrorReason sets the "error_reason" field. +func (_u *IdempotencyRecordUpdate) SetErrorReason(v string) *IdempotencyRecordUpdate { + _u.mutation.SetErrorReason(v) + return _u +} + +// SetNillableErrorReason sets the "error_reason" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableErrorReason(v *string) *IdempotencyRecordUpdate { + if v != nil { + _u.SetErrorReason(*v) + } + return _u +} + +// ClearErrorReason clears the value of the "error_reason" field. 
+func (_u *IdempotencyRecordUpdate) ClearErrorReason() *IdempotencyRecordUpdate { + _u.mutation.ClearErrorReason() + return _u +} + +// SetLockedUntil sets the "locked_until" field. +func (_u *IdempotencyRecordUpdate) SetLockedUntil(v time.Time) *IdempotencyRecordUpdate { + _u.mutation.SetLockedUntil(v) + return _u +} + +// SetNillableLockedUntil sets the "locked_until" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableLockedUntil(v *time.Time) *IdempotencyRecordUpdate { + if v != nil { + _u.SetLockedUntil(*v) + } + return _u +} + +// ClearLockedUntil clears the value of the "locked_until" field. +func (_u *IdempotencyRecordUpdate) ClearLockedUntil() *IdempotencyRecordUpdate { + _u.mutation.ClearLockedUntil() + return _u +} + +// SetExpiresAt sets the "expires_at" field. +func (_u *IdempotencyRecordUpdate) SetExpiresAt(v time.Time) *IdempotencyRecordUpdate { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableExpiresAt(v *time.Time) *IdempotencyRecordUpdate { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// Mutation returns the IdempotencyRecordMutation object of the builder. +func (_u *IdempotencyRecordUpdate) Mutation() *IdempotencyRecordMutation { + return _u.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *IdempotencyRecordUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *IdempotencyRecordUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. 
+func (_u *IdempotencyRecordUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *IdempotencyRecordUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *IdempotencyRecordUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := idempotencyrecord.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *IdempotencyRecordUpdate) check() error { + if v, ok := _u.mutation.Scope(); ok { + if err := idempotencyrecord.ScopeValidator(v); err != nil { + return &ValidationError{Name: "scope", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.scope": %w`, err)} + } + } + if v, ok := _u.mutation.IdempotencyKeyHash(); ok { + if err := idempotencyrecord.IdempotencyKeyHashValidator(v); err != nil { + return &ValidationError{Name: "idempotency_key_hash", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.idempotency_key_hash": %w`, err)} + } + } + if v, ok := _u.mutation.RequestFingerprint(); ok { + if err := idempotencyrecord.RequestFingerprintValidator(v); err != nil { + return &ValidationError{Name: "request_fingerprint", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.request_fingerprint": %w`, err)} + } + } + if v, ok := _u.mutation.Status(); ok { + if err := idempotencyrecord.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.status": %w`, err)} + } + } + if v, ok := _u.mutation.ErrorReason(); ok { + if err := idempotencyrecord.ErrorReasonValidator(v); err != nil { + return &ValidationError{Name: "error_reason", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.error_reason": %w`, err)} + } + } + 
return nil +} + +func (_u *IdempotencyRecordUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(idempotencyrecord.Table, idempotencyrecord.Columns, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(idempotencyrecord.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Scope(); ok { + _spec.SetField(idempotencyrecord.FieldScope, field.TypeString, value) + } + if value, ok := _u.mutation.IdempotencyKeyHash(); ok { + _spec.SetField(idempotencyrecord.FieldIdempotencyKeyHash, field.TypeString, value) + } + if value, ok := _u.mutation.RequestFingerprint(); ok { + _spec.SetField(idempotencyrecord.FieldRequestFingerprint, field.TypeString, value) + } + if value, ok := _u.mutation.Status(); ok { + _spec.SetField(idempotencyrecord.FieldStatus, field.TypeString, value) + } + if value, ok := _u.mutation.ResponseStatus(); ok { + _spec.SetField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedResponseStatus(); ok { + _spec.AddField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value) + } + if _u.mutation.ResponseStatusCleared() { + _spec.ClearField(idempotencyrecord.FieldResponseStatus, field.TypeInt) + } + if value, ok := _u.mutation.ResponseBody(); ok { + _spec.SetField(idempotencyrecord.FieldResponseBody, field.TypeString, value) + } + if _u.mutation.ResponseBodyCleared() { + _spec.ClearField(idempotencyrecord.FieldResponseBody, field.TypeString) + } + if value, ok := _u.mutation.ErrorReason(); ok { + _spec.SetField(idempotencyrecord.FieldErrorReason, field.TypeString, value) + } + if _u.mutation.ErrorReasonCleared() { + 
_spec.ClearField(idempotencyrecord.FieldErrorReason, field.TypeString) + } + if value, ok := _u.mutation.LockedUntil(); ok { + _spec.SetField(idempotencyrecord.FieldLockedUntil, field.TypeTime, value) + } + if _u.mutation.LockedUntilCleared() { + _spec.ClearField(idempotencyrecord.FieldLockedUntil, field.TypeTime) + } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(idempotencyrecord.FieldExpiresAt, field.TypeTime, value) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{idempotencyrecord.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// IdempotencyRecordUpdateOne is the builder for updating a single IdempotencyRecord entity. +type IdempotencyRecordUpdateOne struct { + config + fields []string + hooks []Hook + mutation *IdempotencyRecordMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *IdempotencyRecordUpdateOne) SetUpdatedAt(v time.Time) *IdempotencyRecordUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetScope sets the "scope" field. +func (_u *IdempotencyRecordUpdateOne) SetScope(v string) *IdempotencyRecordUpdateOne { + _u.mutation.SetScope(v) + return _u +} + +// SetNillableScope sets the "scope" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableScope(v *string) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetScope(*v) + } + return _u +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. +func (_u *IdempotencyRecordUpdateOne) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpdateOne { + _u.mutation.SetIdempotencyKeyHash(v) + return _u +} + +// SetNillableIdempotencyKeyHash sets the "idempotency_key_hash" field if the given value is not nil. 
+func (_u *IdempotencyRecordUpdateOne) SetNillableIdempotencyKeyHash(v *string) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetIdempotencyKeyHash(*v) + } + return _u +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (_u *IdempotencyRecordUpdateOne) SetRequestFingerprint(v string) *IdempotencyRecordUpdateOne { + _u.mutation.SetRequestFingerprint(v) + return _u +} + +// SetNillableRequestFingerprint sets the "request_fingerprint" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableRequestFingerprint(v *string) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetRequestFingerprint(*v) + } + return _u +} + +// SetStatus sets the "status" field. +func (_u *IdempotencyRecordUpdateOne) SetStatus(v string) *IdempotencyRecordUpdateOne { + _u.mutation.SetStatus(v) + return _u +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableStatus(v *string) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetStatus(*v) + } + return _u +} + +// SetResponseStatus sets the "response_status" field. +func (_u *IdempotencyRecordUpdateOne) SetResponseStatus(v int) *IdempotencyRecordUpdateOne { + _u.mutation.ResetResponseStatus() + _u.mutation.SetResponseStatus(v) + return _u +} + +// SetNillableResponseStatus sets the "response_status" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableResponseStatus(v *int) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetResponseStatus(*v) + } + return _u +} + +// AddResponseStatus adds value to the "response_status" field. +func (_u *IdempotencyRecordUpdateOne) AddResponseStatus(v int) *IdempotencyRecordUpdateOne { + _u.mutation.AddResponseStatus(v) + return _u +} + +// ClearResponseStatus clears the value of the "response_status" field. 
+func (_u *IdempotencyRecordUpdateOne) ClearResponseStatus() *IdempotencyRecordUpdateOne { + _u.mutation.ClearResponseStatus() + return _u +} + +// SetResponseBody sets the "response_body" field. +func (_u *IdempotencyRecordUpdateOne) SetResponseBody(v string) *IdempotencyRecordUpdateOne { + _u.mutation.SetResponseBody(v) + return _u +} + +// SetNillableResponseBody sets the "response_body" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableResponseBody(v *string) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetResponseBody(*v) + } + return _u +} + +// ClearResponseBody clears the value of the "response_body" field. +func (_u *IdempotencyRecordUpdateOne) ClearResponseBody() *IdempotencyRecordUpdateOne { + _u.mutation.ClearResponseBody() + return _u +} + +// SetErrorReason sets the "error_reason" field. +func (_u *IdempotencyRecordUpdateOne) SetErrorReason(v string) *IdempotencyRecordUpdateOne { + _u.mutation.SetErrorReason(v) + return _u +} + +// SetNillableErrorReason sets the "error_reason" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableErrorReason(v *string) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetErrorReason(*v) + } + return _u +} + +// ClearErrorReason clears the value of the "error_reason" field. +func (_u *IdempotencyRecordUpdateOne) ClearErrorReason() *IdempotencyRecordUpdateOne { + _u.mutation.ClearErrorReason() + return _u +} + +// SetLockedUntil sets the "locked_until" field. +func (_u *IdempotencyRecordUpdateOne) SetLockedUntil(v time.Time) *IdempotencyRecordUpdateOne { + _u.mutation.SetLockedUntil(v) + return _u +} + +// SetNillableLockedUntil sets the "locked_until" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableLockedUntil(v *time.Time) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetLockedUntil(*v) + } + return _u +} + +// ClearLockedUntil clears the value of the "locked_until" field. 
+func (_u *IdempotencyRecordUpdateOne) ClearLockedUntil() *IdempotencyRecordUpdateOne { + _u.mutation.ClearLockedUntil() + return _u +} + +// SetExpiresAt sets the "expires_at" field. +func (_u *IdempotencyRecordUpdateOne) SetExpiresAt(v time.Time) *IdempotencyRecordUpdateOne { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableExpiresAt(v *time.Time) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// Mutation returns the IdempotencyRecordMutation object of the builder. +func (_u *IdempotencyRecordUpdateOne) Mutation() *IdempotencyRecordMutation { + return _u.mutation +} + +// Where appends a list predicates to the IdempotencyRecordUpdate builder. +func (_u *IdempotencyRecordUpdateOne) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *IdempotencyRecordUpdateOne) Select(field string, fields ...string) *IdempotencyRecordUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated IdempotencyRecord entity. +func (_u *IdempotencyRecordUpdateOne) Save(ctx context.Context) (*IdempotencyRecord, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *IdempotencyRecordUpdateOne) SaveX(ctx context.Context) *IdempotencyRecord { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *IdempotencyRecordUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (_u *IdempotencyRecordUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *IdempotencyRecordUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := idempotencyrecord.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *IdempotencyRecordUpdateOne) check() error { + if v, ok := _u.mutation.Scope(); ok { + if err := idempotencyrecord.ScopeValidator(v); err != nil { + return &ValidationError{Name: "scope", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.scope": %w`, err)} + } + } + if v, ok := _u.mutation.IdempotencyKeyHash(); ok { + if err := idempotencyrecord.IdempotencyKeyHashValidator(v); err != nil { + return &ValidationError{Name: "idempotency_key_hash", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.idempotency_key_hash": %w`, err)} + } + } + if v, ok := _u.mutation.RequestFingerprint(); ok { + if err := idempotencyrecord.RequestFingerprintValidator(v); err != nil { + return &ValidationError{Name: "request_fingerprint", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.request_fingerprint": %w`, err)} + } + } + if v, ok := _u.mutation.Status(); ok { + if err := idempotencyrecord.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.status": %w`, err)} + } + } + if v, ok := _u.mutation.ErrorReason(); ok { + if err := idempotencyrecord.ErrorReasonValidator(v); err != nil { + return &ValidationError{Name: "error_reason", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.error_reason": %w`, err)} + } + } + return nil +} + +func (_u *IdempotencyRecordUpdateOne) sqlSave(ctx context.Context) (_node *IdempotencyRecord, err error) { + if err := _u.check(); err != nil { + 
return _node, err + } + _spec := sqlgraph.NewUpdateSpec(idempotencyrecord.Table, idempotencyrecord.Columns, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "IdempotencyRecord.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, idempotencyrecord.FieldID) + for _, f := range fields { + if !idempotencyrecord.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != idempotencyrecord.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(idempotencyrecord.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Scope(); ok { + _spec.SetField(idempotencyrecord.FieldScope, field.TypeString, value) + } + if value, ok := _u.mutation.IdempotencyKeyHash(); ok { + _spec.SetField(idempotencyrecord.FieldIdempotencyKeyHash, field.TypeString, value) + } + if value, ok := _u.mutation.RequestFingerprint(); ok { + _spec.SetField(idempotencyrecord.FieldRequestFingerprint, field.TypeString, value) + } + if value, ok := _u.mutation.Status(); ok { + _spec.SetField(idempotencyrecord.FieldStatus, field.TypeString, value) + } + if value, ok := _u.mutation.ResponseStatus(); ok { + _spec.SetField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedResponseStatus(); ok { + _spec.AddField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value) + } + if _u.mutation.ResponseStatusCleared() { + _spec.ClearField(idempotencyrecord.FieldResponseStatus, 
field.TypeInt) + } + if value, ok := _u.mutation.ResponseBody(); ok { + _spec.SetField(idempotencyrecord.FieldResponseBody, field.TypeString, value) + } + if _u.mutation.ResponseBodyCleared() { + _spec.ClearField(idempotencyrecord.FieldResponseBody, field.TypeString) + } + if value, ok := _u.mutation.ErrorReason(); ok { + _spec.SetField(idempotencyrecord.FieldErrorReason, field.TypeString, value) + } + if _u.mutation.ErrorReasonCleared() { + _spec.ClearField(idempotencyrecord.FieldErrorReason, field.TypeString) + } + if value, ok := _u.mutation.LockedUntil(); ok { + _spec.SetField(idempotencyrecord.FieldLockedUntil, field.TypeTime, value) + } + if _u.mutation.LockedUntilCleared() { + _spec.ClearField(idempotencyrecord.FieldLockedUntil, field.TypeTime) + } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(idempotencyrecord.FieldExpiresAt, field.TypeTime, value) + } + _node = &IdempotencyRecord{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{idempotencyrecord.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go index 8ee42db3..e7746402 100644 --- a/backend/ent/intercept/intercept.go +++ b/backend/ent/intercept/intercept.go @@ -15,11 +15,13 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + 
"github.com/Wei-Shaw/sub2api/ent/securitysecret" "github.com/Wei-Shaw/sub2api/ent/setting" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" @@ -275,6 +277,33 @@ func (f TraverseGroup) Traverse(ctx context.Context, q ent.Query) error { return fmt.Errorf("unexpected query type %T. expect *ent.GroupQuery", q) } +// The IdempotencyRecordFunc type is an adapter to allow the use of ordinary function as a Querier. +type IdempotencyRecordFunc func(context.Context, *ent.IdempotencyRecordQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f IdempotencyRecordFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.IdempotencyRecordQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.IdempotencyRecordQuery", q) +} + +// The TraverseIdempotencyRecord type is an adapter to allow the use of ordinary function as Traverser. +type TraverseIdempotencyRecord func(context.Context, *ent.IdempotencyRecordQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseIdempotencyRecord) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseIdempotencyRecord) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.IdempotencyRecordQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.IdempotencyRecordQuery", q) +} + // The PromoCodeFunc type is an adapter to allow the use of ordinary function as a Querier. type PromoCodeFunc func(context.Context, *ent.PromoCodeQuery) (ent.Value, error) @@ -383,6 +412,33 @@ func (f TraverseRedeemCode) Traverse(ctx context.Context, q ent.Query) error { return fmt.Errorf("unexpected query type %T. expect *ent.RedeemCodeQuery", q) } +// The SecuritySecretFunc type is an adapter to allow the use of ordinary function as a Querier. 
+type SecuritySecretFunc func(context.Context, *ent.SecuritySecretQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f SecuritySecretFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.SecuritySecretQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.SecuritySecretQuery", q) +} + +// The TraverseSecuritySecret type is an adapter to allow the use of ordinary function as Traverser. +type TraverseSecuritySecret func(context.Context, *ent.SecuritySecretQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseSecuritySecret) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseSecuritySecret) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.SecuritySecretQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.SecuritySecretQuery", q) +} + // The SettingFunc type is an adapter to allow the use of ordinary function as a Querier. 
type SettingFunc func(context.Context, *ent.SettingQuery) (ent.Value, error) @@ -616,6 +672,8 @@ func NewQuery(q ent.Query) (Query, error) { return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil case *ent.GroupQuery: return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil + case *ent.IdempotencyRecordQuery: + return &query[*ent.IdempotencyRecordQuery, predicate.IdempotencyRecord, idempotencyrecord.OrderOption]{typ: ent.TypeIdempotencyRecord, tq: q}, nil case *ent.PromoCodeQuery: return &query[*ent.PromoCodeQuery, predicate.PromoCode, promocode.OrderOption]{typ: ent.TypePromoCode, tq: q}, nil case *ent.PromoCodeUsageQuery: @@ -624,6 +682,8 @@ func NewQuery(q ent.Query) (Query, error) { return &query[*ent.ProxyQuery, predicate.Proxy, proxy.OrderOption]{typ: ent.TypeProxy, tq: q}, nil case *ent.RedeemCodeQuery: return &query[*ent.RedeemCodeQuery, predicate.RedeemCode, redeemcode.OrderOption]{typ: ent.TypeRedeemCode, tq: q}, nil + case *ent.SecuritySecretQuery: + return &query[*ent.SecuritySecretQuery, predicate.SecuritySecret, securitysecret.OrderOption]{typ: ent.TypeSecuritySecret, tq: q}, nil case *ent.SettingQuery: return &query[*ent.SettingQuery, predicate.Setting, setting.OrderOption]{typ: ent.TypeSetting, tq: q}, nil case *ent.UsageCleanupTaskQuery: diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index ef5f1e04..85e94072 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -18,11 +18,21 @@ var ( {Name: "key", Type: field.TypeString, Unique: true, Size: 128}, {Name: "name", Type: field.TypeString, Size: 100}, {Name: "status", Type: field.TypeString, Size: 20, Default: "active"}, + {Name: "last_used_at", Type: field.TypeTime, Nullable: true}, {Name: "ip_whitelist", Type: field.TypeJSON, Nullable: true}, {Name: "ip_blacklist", Type: field.TypeJSON, Nullable: true}, 
{Name: "quota", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "quota_used", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "expires_at", Type: field.TypeTime, Nullable: true}, + {Name: "rate_limit_5h", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "rate_limit_1d", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "rate_limit_7d", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "usage_5h", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "usage_1d", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "usage_7d", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "window_5h_start", Type: field.TypeTime, Nullable: true}, + {Name: "window_1d_start", Type: field.TypeTime, Nullable: true}, + {Name: "window_7d_start", Type: field.TypeTime, Nullable: true}, {Name: "group_id", Type: field.TypeInt64, Nullable: true}, {Name: "user_id", Type: field.TypeInt64}, } @@ -34,13 +44,13 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "api_keys_groups_api_keys", - Columns: []*schema.Column{APIKeysColumns[12]}, + Columns: []*schema.Column{APIKeysColumns[22]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "api_keys_users_api_keys", - Columns: []*schema.Column{APIKeysColumns[13]}, + Columns: []*schema.Column{APIKeysColumns[23]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, @@ -49,12 +59,12 @@ var ( { Name: "apikey_user_id", Unique: false, - Columns: []*schema.Column{APIKeysColumns[13]}, + Columns: []*schema.Column{APIKeysColumns[23]}, }, { Name: 
"apikey_group_id", Unique: false, - Columns: []*schema.Column{APIKeysColumns[12]}, + Columns: []*schema.Column{APIKeysColumns[22]}, }, { Name: "apikey_status", @@ -66,15 +76,20 @@ var ( Unique: false, Columns: []*schema.Column{APIKeysColumns[3]}, }, + { + Name: "apikey_last_used_at", + Unique: false, + Columns: []*schema.Column{APIKeysColumns[7]}, + }, { Name: "apikey_quota_quota_used", Unique: false, - Columns: []*schema.Column{APIKeysColumns[9], APIKeysColumns[10]}, + Columns: []*schema.Column{APIKeysColumns[10], APIKeysColumns[11]}, }, { Name: "apikey_expires_at", Unique: false, - Columns: []*schema.Column{APIKeysColumns[11]}, + Columns: []*schema.Column{APIKeysColumns[12]}, }, }, } @@ -102,6 +117,8 @@ var ( {Name: "rate_limited_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "rate_limit_reset_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "overload_until", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "temp_unschedulable_until", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "temp_unschedulable_reason", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "session_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "session_window_end", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "session_window_status", Type: field.TypeString, Nullable: true, Size: 20}, @@ -115,7 +132,7 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "accounts_proxies_proxy", - Columns: []*schema.Column{AccountsColumns[25]}, + Columns: []*schema.Column{AccountsColumns[27]}, RefColumns: []*schema.Column{ProxiesColumns[0]}, OnDelete: schema.SetNull, }, @@ -139,7 +156,7 @@ var ( { 
Name: "account_proxy_id", Unique: false, - Columns: []*schema.Column{AccountsColumns[25]}, + Columns: []*schema.Column{AccountsColumns[27]}, }, { Name: "account_priority", @@ -171,6 +188,16 @@ var ( Unique: false, Columns: []*schema.Column{AccountsColumns[21]}, }, + { + Name: "account_platform_priority", + Unique: false, + Columns: []*schema.Column{AccountsColumns[6], AccountsColumns[11]}, + }, + { + Name: "account_priority_status", + Unique: false, + Columns: []*schema.Column{AccountsColumns[11], AccountsColumns[13]}, + }, { Name: "account_deleted_at", Unique: false, @@ -366,6 +393,11 @@ var ( {Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "sora_image_price_360", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "sora_image_price_540", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "sora_video_price_per_request", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "sora_video_price_per_request_hd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0}, {Name: "claude_code_only", Type: field.TypeBool, Default: false}, {Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true}, {Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true}, @@ -409,7 +441,45 @@ var ( { Name: "group_sort_order", Unique: false, - Columns: []*schema.Column{GroupsColumns[25]}, + Columns: 
[]*schema.Column{GroupsColumns[30]}, + }, + }, + } + // IdempotencyRecordsColumns holds the columns for the "idempotency_records" table. + IdempotencyRecordsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "scope", Type: field.TypeString, Size: 128}, + {Name: "idempotency_key_hash", Type: field.TypeString, Size: 64}, + {Name: "request_fingerprint", Type: field.TypeString, Size: 64}, + {Name: "status", Type: field.TypeString, Size: 32}, + {Name: "response_status", Type: field.TypeInt, Nullable: true}, + {Name: "response_body", Type: field.TypeString, Nullable: true}, + {Name: "error_reason", Type: field.TypeString, Nullable: true, Size: 128}, + {Name: "locked_until", Type: field.TypeTime, Nullable: true}, + {Name: "expires_at", Type: field.TypeTime}, + } + // IdempotencyRecordsTable holds the schema information for the "idempotency_records" table. + IdempotencyRecordsTable = &schema.Table{ + Name: "idempotency_records", + Columns: IdempotencyRecordsColumns, + PrimaryKey: []*schema.Column{IdempotencyRecordsColumns[0]}, + Indexes: []*schema.Index{ + { + Name: "idempotencyrecord_scope_idempotency_key_hash", + Unique: true, + Columns: []*schema.Column{IdempotencyRecordsColumns[3], IdempotencyRecordsColumns[4]}, + }, + { + Name: "idempotencyrecord_expires_at", + Unique: false, + Columns: []*schema.Column{IdempotencyRecordsColumns[11]}, + }, + { + Name: "idempotencyrecord_status_locked_until", + Unique: false, + Columns: []*schema.Column{IdempotencyRecordsColumns[6], IdempotencyRecordsColumns[10]}, }, }, } @@ -572,6 +642,20 @@ var ( }, }, } + // SecuritySecretsColumns holds the columns for the "security_secrets" table. 
+ SecuritySecretsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "key", Type: field.TypeString, Unique: true, Size: 100}, + {Name: "value", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + } + // SecuritySecretsTable holds the schema information for the "security_secrets" table. + SecuritySecretsTable = &schema.Table{ + Name: "security_secrets", + Columns: SecuritySecretsColumns, + PrimaryKey: []*schema.Column{SecuritySecretsColumns[0]}, + } // SettingsColumns holds the columns for the "settings" table. SettingsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -650,6 +734,8 @@ var ( {Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45}, {Name: "image_count", Type: field.TypeInt, Default: 0}, {Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10}, + {Name: "media_type", Type: field.TypeString, Nullable: true, Size: 16}, + {Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "api_key_id", Type: field.TypeInt64}, {Name: "account_id", Type: field.TypeInt64}, @@ -665,31 +751,31 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "usage_logs_api_keys_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[26]}, + Columns: []*schema.Column{UsageLogsColumns[28]}, RefColumns: []*schema.Column{APIKeysColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_accounts_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, RefColumns: []*schema.Column{AccountsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: 
"usage_logs_groups_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[30]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "usage_logs_users_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[31]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_user_subscriptions_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[32]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, OnDelete: schema.SetNull, }, @@ -698,32 +784,32 @@ var ( { Name: "usagelog_user_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[31]}, }, { Name: "usagelog_api_key_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[26]}, + Columns: []*schema.Column{UsageLogsColumns[28]}, }, { Name: "usagelog_account_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, }, { Name: "usagelog_group_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[30]}, }, { Name: "usagelog_subscription_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[32]}, }, { Name: "usagelog_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[25]}, + Columns: []*schema.Column{UsageLogsColumns[27]}, }, { Name: "usagelog_model", @@ -738,12 +824,17 @@ var ( { Name: "usagelog_user_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[25]}, + Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[27]}, }, { Name: "usagelog_api_key_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[26], UsageLogsColumns[25]}, + 
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]}, + }, + { + Name: "usagelog_group_id_created_at", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[27]}, }, }, } @@ -764,6 +855,8 @@ var ( {Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "totp_enabled", Type: field.TypeBool, Default: false}, {Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true}, + {Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0}, + {Name: "sora_storage_used_bytes", Type: field.TypeInt64, Default: 0}, } // UsersTable holds the schema information for the "users" table. UsersTable = &schema.Table{ @@ -969,6 +1062,11 @@ var ( Unique: false, Columns: []*schema.Column{UserSubscriptionsColumns[5]}, }, + { + Name: "usersubscription_user_id_status_expires_at", + Unique: false, + Columns: []*schema.Column{UserSubscriptionsColumns[16], UserSubscriptionsColumns[6], UserSubscriptionsColumns[5]}, + }, { Name: "usersubscription_assigned_by", Unique: false, @@ -995,10 +1093,12 @@ var ( AnnouncementReadsTable, ErrorPassthroughRulesTable, GroupsTable, + IdempotencyRecordsTable, PromoCodesTable, PromoCodeUsagesTable, ProxiesTable, RedeemCodesTable, + SecuritySecretsTable, SettingsTable, UsageCleanupTasksTable, UsageLogsTable, @@ -1039,6 +1139,9 @@ func init() { GroupsTable.Annotation = &entsql.Annotation{ Table: "groups", } + IdempotencyRecordsTable.Annotation = &entsql.Annotation{ + Table: "idempotency_records", + } PromoCodesTable.Annotation = &entsql.Annotation{ Table: "promo_codes", } @@ -1055,6 +1158,9 @@ func init() { RedeemCodesTable.Annotation = &entsql.Annotation{ Table: "redeem_codes", } + SecuritySecretsTable.Annotation = &entsql.Annotation{ + Table: "security_secrets", + } SettingsTable.Annotation = &entsql.Annotation{ Table: "settings", } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 76360820..85e2ea71 100644 --- 
a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -19,11 +19,13 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" "github.com/Wei-Shaw/sub2api/ent/setting" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" @@ -51,10 +53,12 @@ const ( TypeAnnouncementRead = "AnnouncementRead" TypeErrorPassthroughRule = "ErrorPassthroughRule" TypeGroup = "Group" + TypeIdempotencyRecord = "IdempotencyRecord" TypePromoCode = "PromoCode" TypePromoCodeUsage = "PromoCodeUsage" TypeProxy = "Proxy" TypeRedeemCode = "RedeemCode" + TypeSecuritySecret = "SecuritySecret" TypeSetting = "Setting" TypeUsageCleanupTask = "UsageCleanupTask" TypeUsageLog = "UsageLog" @@ -77,6 +81,7 @@ type APIKeyMutation struct { key *string name *string status *string + last_used_at *time.Time ip_whitelist *[]string appendip_whitelist []string ip_blacklist *[]string @@ -86,6 +91,21 @@ type APIKeyMutation struct { quota_used *float64 addquota_used *float64 expires_at *time.Time + rate_limit_5h *float64 + addrate_limit_5h *float64 + rate_limit_1d *float64 + addrate_limit_1d *float64 + rate_limit_7d *float64 + addrate_limit_7d *float64 + usage_5h *float64 + addusage_5h *float64 + usage_1d *float64 + addusage_1d *float64 + usage_7d *float64 + addusage_7d *float64 + window_5h_start *time.Time + window_1d_start *time.Time + window_7d_start *time.Time clearedFields map[string]struct{} user *int64 cleareduser bool @@ -511,6 +531,55 @@ func (m *APIKeyMutation) ResetStatus() { m.status = nil } +// SetLastUsedAt sets the "last_used_at" field. 
+func (m *APIKeyMutation) SetLastUsedAt(t time.Time) { + m.last_used_at = &t +} + +// LastUsedAt returns the value of the "last_used_at" field in the mutation. +func (m *APIKeyMutation) LastUsedAt() (r time.Time, exists bool) { + v := m.last_used_at + if v == nil { + return + } + return *v, true +} + +// OldLastUsedAt returns the old "last_used_at" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *APIKeyMutation) OldLastUsedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLastUsedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLastUsedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLastUsedAt: %w", err) + } + return oldValue.LastUsedAt, nil +} + +// ClearLastUsedAt clears the value of the "last_used_at" field. +func (m *APIKeyMutation) ClearLastUsedAt() { + m.last_used_at = nil + m.clearedFields[apikey.FieldLastUsedAt] = struct{}{} +} + +// LastUsedAtCleared returns if the "last_used_at" field was cleared in this mutation. +func (m *APIKeyMutation) LastUsedAtCleared() bool { + _, ok := m.clearedFields[apikey.FieldLastUsedAt] + return ok +} + +// ResetLastUsedAt resets all changes to the "last_used_at" field. +func (m *APIKeyMutation) ResetLastUsedAt() { + m.last_used_at = nil + delete(m.clearedFields, apikey.FieldLastUsedAt) +} + // SetIPWhitelist sets the "ip_whitelist" field. func (m *APIKeyMutation) SetIPWhitelist(s []string) { m.ip_whitelist = &s @@ -802,6 +871,489 @@ func (m *APIKeyMutation) ResetExpiresAt() { delete(m.clearedFields, apikey.FieldExpiresAt) } +// SetRateLimit5h sets the "rate_limit_5h" field. 
+func (m *APIKeyMutation) SetRateLimit5h(f float64) { + m.rate_limit_5h = &f + m.addrate_limit_5h = nil +} + +// RateLimit5h returns the value of the "rate_limit_5h" field in the mutation. +func (m *APIKeyMutation) RateLimit5h() (r float64, exists bool) { + v := m.rate_limit_5h + if v == nil { + return + } + return *v, true +} + +// OldRateLimit5h returns the old "rate_limit_5h" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *APIKeyMutation) OldRateLimit5h(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRateLimit5h is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRateLimit5h requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRateLimit5h: %w", err) + } + return oldValue.RateLimit5h, nil +} + +// AddRateLimit5h adds f to the "rate_limit_5h" field. +func (m *APIKeyMutation) AddRateLimit5h(f float64) { + if m.addrate_limit_5h != nil { + *m.addrate_limit_5h += f + } else { + m.addrate_limit_5h = &f + } +} + +// AddedRateLimit5h returns the value that was added to the "rate_limit_5h" field in this mutation. +func (m *APIKeyMutation) AddedRateLimit5h() (r float64, exists bool) { + v := m.addrate_limit_5h + if v == nil { + return + } + return *v, true +} + +// ResetRateLimit5h resets all changes to the "rate_limit_5h" field. +func (m *APIKeyMutation) ResetRateLimit5h() { + m.rate_limit_5h = nil + m.addrate_limit_5h = nil +} + +// SetRateLimit1d sets the "rate_limit_1d" field. +func (m *APIKeyMutation) SetRateLimit1d(f float64) { + m.rate_limit_1d = &f + m.addrate_limit_1d = nil +} + +// RateLimit1d returns the value of the "rate_limit_1d" field in the mutation. 
+func (m *APIKeyMutation) RateLimit1d() (r float64, exists bool) { + v := m.rate_limit_1d + if v == nil { + return + } + return *v, true +} + +// OldRateLimit1d returns the old "rate_limit_1d" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *APIKeyMutation) OldRateLimit1d(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRateLimit1d is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRateLimit1d requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRateLimit1d: %w", err) + } + return oldValue.RateLimit1d, nil +} + +// AddRateLimit1d adds f to the "rate_limit_1d" field. +func (m *APIKeyMutation) AddRateLimit1d(f float64) { + if m.addrate_limit_1d != nil { + *m.addrate_limit_1d += f + } else { + m.addrate_limit_1d = &f + } +} + +// AddedRateLimit1d returns the value that was added to the "rate_limit_1d" field in this mutation. +func (m *APIKeyMutation) AddedRateLimit1d() (r float64, exists bool) { + v := m.addrate_limit_1d + if v == nil { + return + } + return *v, true +} + +// ResetRateLimit1d resets all changes to the "rate_limit_1d" field. +func (m *APIKeyMutation) ResetRateLimit1d() { + m.rate_limit_1d = nil + m.addrate_limit_1d = nil +} + +// SetRateLimit7d sets the "rate_limit_7d" field. +func (m *APIKeyMutation) SetRateLimit7d(f float64) { + m.rate_limit_7d = &f + m.addrate_limit_7d = nil +} + +// RateLimit7d returns the value of the "rate_limit_7d" field in the mutation. 
+func (m *APIKeyMutation) RateLimit7d() (r float64, exists bool) { + v := m.rate_limit_7d + if v == nil { + return + } + return *v, true +} + +// OldRateLimit7d returns the old "rate_limit_7d" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *APIKeyMutation) OldRateLimit7d(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRateLimit7d is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRateLimit7d requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRateLimit7d: %w", err) + } + return oldValue.RateLimit7d, nil +} + +// AddRateLimit7d adds f to the "rate_limit_7d" field. +func (m *APIKeyMutation) AddRateLimit7d(f float64) { + if m.addrate_limit_7d != nil { + *m.addrate_limit_7d += f + } else { + m.addrate_limit_7d = &f + } +} + +// AddedRateLimit7d returns the value that was added to the "rate_limit_7d" field in this mutation. +func (m *APIKeyMutation) AddedRateLimit7d() (r float64, exists bool) { + v := m.addrate_limit_7d + if v == nil { + return + } + return *v, true +} + +// ResetRateLimit7d resets all changes to the "rate_limit_7d" field. +func (m *APIKeyMutation) ResetRateLimit7d() { + m.rate_limit_7d = nil + m.addrate_limit_7d = nil +} + +// SetUsage5h sets the "usage_5h" field. +func (m *APIKeyMutation) SetUsage5h(f float64) { + m.usage_5h = &f + m.addusage_5h = nil +} + +// Usage5h returns the value of the "usage_5h" field in the mutation. +func (m *APIKeyMutation) Usage5h() (r float64, exists bool) { + v := m.usage_5h + if v == nil { + return + } + return *v, true +} + +// OldUsage5h returns the old "usage_5h" field's value of the APIKey entity. 
+// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *APIKeyMutation) OldUsage5h(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUsage5h is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUsage5h requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUsage5h: %w", err) + } + return oldValue.Usage5h, nil +} + +// AddUsage5h adds f to the "usage_5h" field. +func (m *APIKeyMutation) AddUsage5h(f float64) { + if m.addusage_5h != nil { + *m.addusage_5h += f + } else { + m.addusage_5h = &f + } +} + +// AddedUsage5h returns the value that was added to the "usage_5h" field in this mutation. +func (m *APIKeyMutation) AddedUsage5h() (r float64, exists bool) { + v := m.addusage_5h + if v == nil { + return + } + return *v, true +} + +// ResetUsage5h resets all changes to the "usage_5h" field. +func (m *APIKeyMutation) ResetUsage5h() { + m.usage_5h = nil + m.addusage_5h = nil +} + +// SetUsage1d sets the "usage_1d" field. +func (m *APIKeyMutation) SetUsage1d(f float64) { + m.usage_1d = &f + m.addusage_1d = nil +} + +// Usage1d returns the value of the "usage_1d" field in the mutation. +func (m *APIKeyMutation) Usage1d() (r float64, exists bool) { + v := m.usage_1d + if v == nil { + return + } + return *v, true +} + +// OldUsage1d returns the old "usage_1d" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *APIKeyMutation) OldUsage1d(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUsage1d is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUsage1d requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUsage1d: %w", err) + } + return oldValue.Usage1d, nil +} + +// AddUsage1d adds f to the "usage_1d" field. +func (m *APIKeyMutation) AddUsage1d(f float64) { + if m.addusage_1d != nil { + *m.addusage_1d += f + } else { + m.addusage_1d = &f + } +} + +// AddedUsage1d returns the value that was added to the "usage_1d" field in this mutation. +func (m *APIKeyMutation) AddedUsage1d() (r float64, exists bool) { + v := m.addusage_1d + if v == nil { + return + } + return *v, true +} + +// ResetUsage1d resets all changes to the "usage_1d" field. +func (m *APIKeyMutation) ResetUsage1d() { + m.usage_1d = nil + m.addusage_1d = nil +} + +// SetUsage7d sets the "usage_7d" field. +func (m *APIKeyMutation) SetUsage7d(f float64) { + m.usage_7d = &f + m.addusage_7d = nil +} + +// Usage7d returns the value of the "usage_7d" field in the mutation. +func (m *APIKeyMutation) Usage7d() (r float64, exists bool) { + v := m.usage_7d + if v == nil { + return + } + return *v, true +} + +// OldUsage7d returns the old "usage_7d" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *APIKeyMutation) OldUsage7d(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUsage7d is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUsage7d requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUsage7d: %w", err) + } + return oldValue.Usage7d, nil +} + +// AddUsage7d adds f to the "usage_7d" field. +func (m *APIKeyMutation) AddUsage7d(f float64) { + if m.addusage_7d != nil { + *m.addusage_7d += f + } else { + m.addusage_7d = &f + } +} + +// AddedUsage7d returns the value that was added to the "usage_7d" field in this mutation. +func (m *APIKeyMutation) AddedUsage7d() (r float64, exists bool) { + v := m.addusage_7d + if v == nil { + return + } + return *v, true +} + +// ResetUsage7d resets all changes to the "usage_7d" field. +func (m *APIKeyMutation) ResetUsage7d() { + m.usage_7d = nil + m.addusage_7d = nil +} + +// SetWindow5hStart sets the "window_5h_start" field. +func (m *APIKeyMutation) SetWindow5hStart(t time.Time) { + m.window_5h_start = &t +} + +// Window5hStart returns the value of the "window_5h_start" field in the mutation. +func (m *APIKeyMutation) Window5hStart() (r time.Time, exists bool) { + v := m.window_5h_start + if v == nil { + return + } + return *v, true +} + +// OldWindow5hStart returns the old "window_5h_start" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *APIKeyMutation) OldWindow5hStart(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldWindow5hStart is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldWindow5hStart requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldWindow5hStart: %w", err) + } + return oldValue.Window5hStart, nil +} + +// ClearWindow5hStart clears the value of the "window_5h_start" field. +func (m *APIKeyMutation) ClearWindow5hStart() { + m.window_5h_start = nil + m.clearedFields[apikey.FieldWindow5hStart] = struct{}{} +} + +// Window5hStartCleared returns if the "window_5h_start" field was cleared in this mutation. +func (m *APIKeyMutation) Window5hStartCleared() bool { + _, ok := m.clearedFields[apikey.FieldWindow5hStart] + return ok +} + +// ResetWindow5hStart resets all changes to the "window_5h_start" field. +func (m *APIKeyMutation) ResetWindow5hStart() { + m.window_5h_start = nil + delete(m.clearedFields, apikey.FieldWindow5hStart) +} + +// SetWindow1dStart sets the "window_1d_start" field. +func (m *APIKeyMutation) SetWindow1dStart(t time.Time) { + m.window_1d_start = &t +} + +// Window1dStart returns the value of the "window_1d_start" field in the mutation. +func (m *APIKeyMutation) Window1dStart() (r time.Time, exists bool) { + v := m.window_1d_start + if v == nil { + return + } + return *v, true +} + +// OldWindow1dStart returns the old "window_1d_start" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *APIKeyMutation) OldWindow1dStart(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldWindow1dStart is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldWindow1dStart requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldWindow1dStart: %w", err) + } + return oldValue.Window1dStart, nil +} + +// ClearWindow1dStart clears the value of the "window_1d_start" field. +func (m *APIKeyMutation) ClearWindow1dStart() { + m.window_1d_start = nil + m.clearedFields[apikey.FieldWindow1dStart] = struct{}{} +} + +// Window1dStartCleared returns if the "window_1d_start" field was cleared in this mutation. +func (m *APIKeyMutation) Window1dStartCleared() bool { + _, ok := m.clearedFields[apikey.FieldWindow1dStart] + return ok +} + +// ResetWindow1dStart resets all changes to the "window_1d_start" field. +func (m *APIKeyMutation) ResetWindow1dStart() { + m.window_1d_start = nil + delete(m.clearedFields, apikey.FieldWindow1dStart) +} + +// SetWindow7dStart sets the "window_7d_start" field. +func (m *APIKeyMutation) SetWindow7dStart(t time.Time) { + m.window_7d_start = &t +} + +// Window7dStart returns the value of the "window_7d_start" field in the mutation. +func (m *APIKeyMutation) Window7dStart() (r time.Time, exists bool) { + v := m.window_7d_start + if v == nil { + return + } + return *v, true +} + +// OldWindow7dStart returns the old "window_7d_start" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *APIKeyMutation) OldWindow7dStart(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldWindow7dStart is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldWindow7dStart requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldWindow7dStart: %w", err) + } + return oldValue.Window7dStart, nil +} + +// ClearWindow7dStart clears the value of the "window_7d_start" field. +func (m *APIKeyMutation) ClearWindow7dStart() { + m.window_7d_start = nil + m.clearedFields[apikey.FieldWindow7dStart] = struct{}{} +} + +// Window7dStartCleared returns if the "window_7d_start" field was cleared in this mutation. +func (m *APIKeyMutation) Window7dStartCleared() bool { + _, ok := m.clearedFields[apikey.FieldWindow7dStart] + return ok +} + +// ResetWindow7dStart resets all changes to the "window_7d_start" field. +func (m *APIKeyMutation) ResetWindow7dStart() { + m.window_7d_start = nil + delete(m.clearedFields, apikey.FieldWindow7dStart) +} + // ClearUser clears the "user" edge to the User entity. func (m *APIKeyMutation) ClearUser() { m.cleareduser = true @@ -944,7 +1496,7 @@ func (m *APIKeyMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
func (m *APIKeyMutation) Fields() []string { - fields := make([]string, 0, 13) + fields := make([]string, 0, 23) if m.created_at != nil { fields = append(fields, apikey.FieldCreatedAt) } @@ -969,6 +1521,9 @@ func (m *APIKeyMutation) Fields() []string { if m.status != nil { fields = append(fields, apikey.FieldStatus) } + if m.last_used_at != nil { + fields = append(fields, apikey.FieldLastUsedAt) + } if m.ip_whitelist != nil { fields = append(fields, apikey.FieldIPWhitelist) } @@ -984,6 +1539,33 @@ func (m *APIKeyMutation) Fields() []string { if m.expires_at != nil { fields = append(fields, apikey.FieldExpiresAt) } + if m.rate_limit_5h != nil { + fields = append(fields, apikey.FieldRateLimit5h) + } + if m.rate_limit_1d != nil { + fields = append(fields, apikey.FieldRateLimit1d) + } + if m.rate_limit_7d != nil { + fields = append(fields, apikey.FieldRateLimit7d) + } + if m.usage_5h != nil { + fields = append(fields, apikey.FieldUsage5h) + } + if m.usage_1d != nil { + fields = append(fields, apikey.FieldUsage1d) + } + if m.usage_7d != nil { + fields = append(fields, apikey.FieldUsage7d) + } + if m.window_5h_start != nil { + fields = append(fields, apikey.FieldWindow5hStart) + } + if m.window_1d_start != nil { + fields = append(fields, apikey.FieldWindow1dStart) + } + if m.window_7d_start != nil { + fields = append(fields, apikey.FieldWindow7dStart) + } return fields } @@ -1008,6 +1590,8 @@ func (m *APIKeyMutation) Field(name string) (ent.Value, bool) { return m.GroupID() case apikey.FieldStatus: return m.Status() + case apikey.FieldLastUsedAt: + return m.LastUsedAt() case apikey.FieldIPWhitelist: return m.IPWhitelist() case apikey.FieldIPBlacklist: @@ -1018,6 +1602,24 @@ func (m *APIKeyMutation) Field(name string) (ent.Value, bool) { return m.QuotaUsed() case apikey.FieldExpiresAt: return m.ExpiresAt() + case apikey.FieldRateLimit5h: + return m.RateLimit5h() + case apikey.FieldRateLimit1d: + return m.RateLimit1d() + case apikey.FieldRateLimit7d: + return 
m.RateLimit7d() + case apikey.FieldUsage5h: + return m.Usage5h() + case apikey.FieldUsage1d: + return m.Usage1d() + case apikey.FieldUsage7d: + return m.Usage7d() + case apikey.FieldWindow5hStart: + return m.Window5hStart() + case apikey.FieldWindow1dStart: + return m.Window1dStart() + case apikey.FieldWindow7dStart: + return m.Window7dStart() } return nil, false } @@ -1043,6 +1645,8 @@ func (m *APIKeyMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldGroupID(ctx) case apikey.FieldStatus: return m.OldStatus(ctx) + case apikey.FieldLastUsedAt: + return m.OldLastUsedAt(ctx) case apikey.FieldIPWhitelist: return m.OldIPWhitelist(ctx) case apikey.FieldIPBlacklist: @@ -1053,6 +1657,24 @@ func (m *APIKeyMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldQuotaUsed(ctx) case apikey.FieldExpiresAt: return m.OldExpiresAt(ctx) + case apikey.FieldRateLimit5h: + return m.OldRateLimit5h(ctx) + case apikey.FieldRateLimit1d: + return m.OldRateLimit1d(ctx) + case apikey.FieldRateLimit7d: + return m.OldRateLimit7d(ctx) + case apikey.FieldUsage5h: + return m.OldUsage5h(ctx) + case apikey.FieldUsage1d: + return m.OldUsage1d(ctx) + case apikey.FieldUsage7d: + return m.OldUsage7d(ctx) + case apikey.FieldWindow5hStart: + return m.OldWindow5hStart(ctx) + case apikey.FieldWindow1dStart: + return m.OldWindow1dStart(ctx) + case apikey.FieldWindow7dStart: + return m.OldWindow7dStart(ctx) } return nil, fmt.Errorf("unknown APIKey field %s", name) } @@ -1118,6 +1740,13 @@ func (m *APIKeyMutation) SetField(name string, value ent.Value) error { } m.SetStatus(v) return nil + case apikey.FieldLastUsedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLastUsedAt(v) + return nil case apikey.FieldIPWhitelist: v, ok := value.([]string) if !ok { @@ -1153,6 +1782,69 @@ func (m *APIKeyMutation) SetField(name string, value ent.Value) error { } m.SetExpiresAt(v) return nil + case 
apikey.FieldRateLimit5h: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRateLimit5h(v) + return nil + case apikey.FieldRateLimit1d: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRateLimit1d(v) + return nil + case apikey.FieldRateLimit7d: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRateLimit7d(v) + return nil + case apikey.FieldUsage5h: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUsage5h(v) + return nil + case apikey.FieldUsage1d: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUsage1d(v) + return nil + case apikey.FieldUsage7d: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUsage7d(v) + return nil + case apikey.FieldWindow5hStart: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetWindow5hStart(v) + return nil + case apikey.FieldWindow1dStart: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetWindow1dStart(v) + return nil + case apikey.FieldWindow7dStart: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetWindow7dStart(v) + return nil } return fmt.Errorf("unknown APIKey field %s", name) } @@ -1167,6 +1859,24 @@ func (m *APIKeyMutation) AddedFields() []string { if m.addquota_used != nil { fields = append(fields, apikey.FieldQuotaUsed) } + if m.addrate_limit_5h != nil { + fields = append(fields, apikey.FieldRateLimit5h) + } + if m.addrate_limit_1d != nil { + fields = append(fields, apikey.FieldRateLimit1d) + } + if m.addrate_limit_7d != nil { + 
fields = append(fields, apikey.FieldRateLimit7d) + } + if m.addusage_5h != nil { + fields = append(fields, apikey.FieldUsage5h) + } + if m.addusage_1d != nil { + fields = append(fields, apikey.FieldUsage1d) + } + if m.addusage_7d != nil { + fields = append(fields, apikey.FieldUsage7d) + } return fields } @@ -1179,6 +1889,18 @@ func (m *APIKeyMutation) AddedField(name string) (ent.Value, bool) { return m.AddedQuota() case apikey.FieldQuotaUsed: return m.AddedQuotaUsed() + case apikey.FieldRateLimit5h: + return m.AddedRateLimit5h() + case apikey.FieldRateLimit1d: + return m.AddedRateLimit1d() + case apikey.FieldRateLimit7d: + return m.AddedRateLimit7d() + case apikey.FieldUsage5h: + return m.AddedUsage5h() + case apikey.FieldUsage1d: + return m.AddedUsage1d() + case apikey.FieldUsage7d: + return m.AddedUsage7d() } return nil, false } @@ -1202,6 +1924,48 @@ func (m *APIKeyMutation) AddField(name string, value ent.Value) error { } m.AddQuotaUsed(v) return nil + case apikey.FieldRateLimit5h: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddRateLimit5h(v) + return nil + case apikey.FieldRateLimit1d: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddRateLimit1d(v) + return nil + case apikey.FieldRateLimit7d: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddRateLimit7d(v) + return nil + case apikey.FieldUsage5h: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddUsage5h(v) + return nil + case apikey.FieldUsage1d: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddUsage1d(v) + return nil + case apikey.FieldUsage7d: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddUsage7d(v) + 
return nil } return fmt.Errorf("unknown APIKey numeric field %s", name) } @@ -1216,6 +1980,9 @@ func (m *APIKeyMutation) ClearedFields() []string { if m.FieldCleared(apikey.FieldGroupID) { fields = append(fields, apikey.FieldGroupID) } + if m.FieldCleared(apikey.FieldLastUsedAt) { + fields = append(fields, apikey.FieldLastUsedAt) + } if m.FieldCleared(apikey.FieldIPWhitelist) { fields = append(fields, apikey.FieldIPWhitelist) } @@ -1225,6 +1992,15 @@ func (m *APIKeyMutation) ClearedFields() []string { if m.FieldCleared(apikey.FieldExpiresAt) { fields = append(fields, apikey.FieldExpiresAt) } + if m.FieldCleared(apikey.FieldWindow5hStart) { + fields = append(fields, apikey.FieldWindow5hStart) + } + if m.FieldCleared(apikey.FieldWindow1dStart) { + fields = append(fields, apikey.FieldWindow1dStart) + } + if m.FieldCleared(apikey.FieldWindow7dStart) { + fields = append(fields, apikey.FieldWindow7dStart) + } return fields } @@ -1245,6 +2021,9 @@ func (m *APIKeyMutation) ClearField(name string) error { case apikey.FieldGroupID: m.ClearGroupID() return nil + case apikey.FieldLastUsedAt: + m.ClearLastUsedAt() + return nil case apikey.FieldIPWhitelist: m.ClearIPWhitelist() return nil @@ -1254,6 +2033,15 @@ func (m *APIKeyMutation) ClearField(name string) error { case apikey.FieldExpiresAt: m.ClearExpiresAt() return nil + case apikey.FieldWindow5hStart: + m.ClearWindow5hStart() + return nil + case apikey.FieldWindow1dStart: + m.ClearWindow1dStart() + return nil + case apikey.FieldWindow7dStart: + m.ClearWindow7dStart() + return nil } return fmt.Errorf("unknown APIKey nullable field %s", name) } @@ -1286,6 +2074,9 @@ func (m *APIKeyMutation) ResetField(name string) error { case apikey.FieldStatus: m.ResetStatus() return nil + case apikey.FieldLastUsedAt: + m.ResetLastUsedAt() + return nil case apikey.FieldIPWhitelist: m.ResetIPWhitelist() return nil @@ -1301,6 +2092,33 @@ func (m *APIKeyMutation) ResetField(name string) error { case apikey.FieldExpiresAt: m.ResetExpiresAt() 
return nil + case apikey.FieldRateLimit5h: + m.ResetRateLimit5h() + return nil + case apikey.FieldRateLimit1d: + m.ResetRateLimit1d() + return nil + case apikey.FieldRateLimit7d: + m.ResetRateLimit7d() + return nil + case apikey.FieldUsage5h: + m.ResetUsage5h() + return nil + case apikey.FieldUsage1d: + m.ResetUsage1d() + return nil + case apikey.FieldUsage7d: + m.ResetUsage7d() + return nil + case apikey.FieldWindow5hStart: + m.ResetWindow5hStart() + return nil + case apikey.FieldWindow1dStart: + m.ResetWindow1dStart() + return nil + case apikey.FieldWindow7dStart: + m.ResetWindow7dStart() + return nil } return fmt.Errorf("unknown APIKey field %s", name) } @@ -1428,48 +2246,50 @@ func (m *APIKeyMutation) ResetEdge(name string) error { // AccountMutation represents an operation that mutates the Account nodes in the graph. type AccountMutation struct { config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - deleted_at *time.Time - name *string - notes *string - platform *string - _type *string - credentials *map[string]interface{} - extra *map[string]interface{} - concurrency *int - addconcurrency *int - priority *int - addpriority *int - rate_multiplier *float64 - addrate_multiplier *float64 - status *string - error_message *string - last_used_at *time.Time - expires_at *time.Time - auto_pause_on_expired *bool - schedulable *bool - rate_limited_at *time.Time - rate_limit_reset_at *time.Time - overload_until *time.Time - session_window_start *time.Time - session_window_end *time.Time - session_window_status *string - clearedFields map[string]struct{} - groups map[int64]struct{} - removedgroups map[int64]struct{} - clearedgroups bool - proxy *int64 - clearedproxy bool - usage_logs map[int64]struct{} - removedusage_logs map[int64]struct{} - clearedusage_logs bool - done bool - oldValue func(context.Context) (*Account, error) - predicates []predicate.Account + op Op + typ string + id *int64 + created_at *time.Time + updated_at 
*time.Time + deleted_at *time.Time + name *string + notes *string + platform *string + _type *string + credentials *map[string]interface{} + extra *map[string]interface{} + concurrency *int + addconcurrency *int + priority *int + addpriority *int + rate_multiplier *float64 + addrate_multiplier *float64 + status *string + error_message *string + last_used_at *time.Time + expires_at *time.Time + auto_pause_on_expired *bool + schedulable *bool + rate_limited_at *time.Time + rate_limit_reset_at *time.Time + overload_until *time.Time + temp_unschedulable_until *time.Time + temp_unschedulable_reason *string + session_window_start *time.Time + session_window_end *time.Time + session_window_status *string + clearedFields map[string]struct{} + groups map[int64]struct{} + removedgroups map[int64]struct{} + clearedgroups bool + proxy *int64 + clearedproxy bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool + done bool + oldValue func(context.Context) (*Account, error) + predicates []predicate.Account } var _ ent.Mutation = (*AccountMutation)(nil) @@ -2539,6 +3359,104 @@ func (m *AccountMutation) ResetOverloadUntil() { delete(m.clearedFields, account.FieldOverloadUntil) } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. +func (m *AccountMutation) SetTempUnschedulableUntil(t time.Time) { + m.temp_unschedulable_until = &t +} + +// TempUnschedulableUntil returns the value of the "temp_unschedulable_until" field in the mutation. +func (m *AccountMutation) TempUnschedulableUntil() (r time.Time, exists bool) { + v := m.temp_unschedulable_until + if v == nil { + return + } + return *v, true +} + +// OldTempUnschedulableUntil returns the old "temp_unschedulable_until" field's value of the Account entity. +// If the Account object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *AccountMutation) OldTempUnschedulableUntil(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTempUnschedulableUntil is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTempUnschedulableUntil requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTempUnschedulableUntil: %w", err) + } + return oldValue.TempUnschedulableUntil, nil +} + +// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field. +func (m *AccountMutation) ClearTempUnschedulableUntil() { + m.temp_unschedulable_until = nil + m.clearedFields[account.FieldTempUnschedulableUntil] = struct{}{} +} + +// TempUnschedulableUntilCleared returns if the "temp_unschedulable_until" field was cleared in this mutation. +func (m *AccountMutation) TempUnschedulableUntilCleared() bool { + _, ok := m.clearedFields[account.FieldTempUnschedulableUntil] + return ok +} + +// ResetTempUnschedulableUntil resets all changes to the "temp_unschedulable_until" field. +func (m *AccountMutation) ResetTempUnschedulableUntil() { + m.temp_unschedulable_until = nil + delete(m.clearedFields, account.FieldTempUnschedulableUntil) +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (m *AccountMutation) SetTempUnschedulableReason(s string) { + m.temp_unschedulable_reason = &s +} + +// TempUnschedulableReason returns the value of the "temp_unschedulable_reason" field in the mutation. +func (m *AccountMutation) TempUnschedulableReason() (r string, exists bool) { + v := m.temp_unschedulable_reason + if v == nil { + return + } + return *v, true +} + +// OldTempUnschedulableReason returns the old "temp_unschedulable_reason" field's value of the Account entity. +// If the Account object wasn't provided to the builder, the object is fetched from the database. 
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AccountMutation) OldTempUnschedulableReason(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTempUnschedulableReason is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTempUnschedulableReason requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTempUnschedulableReason: %w", err) + } + return oldValue.TempUnschedulableReason, nil +} + +// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field. +func (m *AccountMutation) ClearTempUnschedulableReason() { + m.temp_unschedulable_reason = nil + m.clearedFields[account.FieldTempUnschedulableReason] = struct{}{} +} + +// TempUnschedulableReasonCleared returns if the "temp_unschedulable_reason" field was cleared in this mutation. +func (m *AccountMutation) TempUnschedulableReasonCleared() bool { + _, ok := m.clearedFields[account.FieldTempUnschedulableReason] + return ok +} + +// ResetTempUnschedulableReason resets all changes to the "temp_unschedulable_reason" field. +func (m *AccountMutation) ResetTempUnschedulableReason() { + m.temp_unschedulable_reason = nil + delete(m.clearedFields, account.FieldTempUnschedulableReason) +} + // SetSessionWindowStart sets the "session_window_start" field. func (m *AccountMutation) SetSessionWindowStart(t time.Time) { m.session_window_start = &t @@ -2855,7 +3773,7 @@ func (m *AccountMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
func (m *AccountMutation) Fields() []string { - fields := make([]string, 0, 25) + fields := make([]string, 0, 27) if m.created_at != nil { fields = append(fields, account.FieldCreatedAt) } @@ -2922,6 +3840,12 @@ func (m *AccountMutation) Fields() []string { if m.overload_until != nil { fields = append(fields, account.FieldOverloadUntil) } + if m.temp_unschedulable_until != nil { + fields = append(fields, account.FieldTempUnschedulableUntil) + } + if m.temp_unschedulable_reason != nil { + fields = append(fields, account.FieldTempUnschedulableReason) + } if m.session_window_start != nil { fields = append(fields, account.FieldSessionWindowStart) } @@ -2983,6 +3907,10 @@ func (m *AccountMutation) Field(name string) (ent.Value, bool) { return m.RateLimitResetAt() case account.FieldOverloadUntil: return m.OverloadUntil() + case account.FieldTempUnschedulableUntil: + return m.TempUnschedulableUntil() + case account.FieldTempUnschedulableReason: + return m.TempUnschedulableReason() case account.FieldSessionWindowStart: return m.SessionWindowStart() case account.FieldSessionWindowEnd: @@ -3042,6 +3970,10 @@ func (m *AccountMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldRateLimitResetAt(ctx) case account.FieldOverloadUntil: return m.OldOverloadUntil(ctx) + case account.FieldTempUnschedulableUntil: + return m.OldTempUnschedulableUntil(ctx) + case account.FieldTempUnschedulableReason: + return m.OldTempUnschedulableReason(ctx) case account.FieldSessionWindowStart: return m.OldSessionWindowStart(ctx) case account.FieldSessionWindowEnd: @@ -3211,6 +4143,20 @@ func (m *AccountMutation) SetField(name string, value ent.Value) error { } m.SetOverloadUntil(v) return nil + case account.FieldTempUnschedulableUntil: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTempUnschedulableUntil(v) + return nil + case account.FieldTempUnschedulableReason: + v, ok := value.(string) + if !ok { + 
return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTempUnschedulableReason(v) + return nil case account.FieldSessionWindowStart: v, ok := value.(time.Time) if !ok { @@ -3328,6 +4274,12 @@ func (m *AccountMutation) ClearedFields() []string { if m.FieldCleared(account.FieldOverloadUntil) { fields = append(fields, account.FieldOverloadUntil) } + if m.FieldCleared(account.FieldTempUnschedulableUntil) { + fields = append(fields, account.FieldTempUnschedulableUntil) + } + if m.FieldCleared(account.FieldTempUnschedulableReason) { + fields = append(fields, account.FieldTempUnschedulableReason) + } if m.FieldCleared(account.FieldSessionWindowStart) { fields = append(fields, account.FieldSessionWindowStart) } @@ -3378,6 +4330,12 @@ func (m *AccountMutation) ClearField(name string) error { case account.FieldOverloadUntil: m.ClearOverloadUntil() return nil + case account.FieldTempUnschedulableUntil: + m.ClearTempUnschedulableUntil() + return nil + case account.FieldTempUnschedulableReason: + m.ClearTempUnschedulableReason() + return nil case account.FieldSessionWindowStart: m.ClearSessionWindowStart() return nil @@ -3461,6 +4419,12 @@ func (m *AccountMutation) ResetField(name string) error { case account.FieldOverloadUntil: m.ResetOverloadUntil() return nil + case account.FieldTempUnschedulableUntil: + m.ResetTempUnschedulableUntil() + return nil + case account.FieldTempUnschedulableReason: + m.ResetTempUnschedulableReason() + return nil case account.FieldSessionWindowStart: m.ResetSessionWindowStart() return nil @@ -7103,6 +8067,16 @@ type GroupMutation struct { addimage_price_2k *float64 image_price_4k *float64 addimage_price_4k *float64 + sora_image_price_360 *float64 + addsora_image_price_360 *float64 + sora_image_price_540 *float64 + addsora_image_price_540 *float64 + sora_video_price_per_request *float64 + addsora_video_price_per_request *float64 + sora_video_price_per_request_hd *float64 + addsora_video_price_per_request_hd *float64 + 
sora_storage_quota_bytes *int64 + addsora_storage_quota_bytes *int64 claude_code_only *bool fallback_group_id *int64 addfallback_group_id *int64 @@ -8119,6 +9093,342 @@ func (m *GroupMutation) ResetImagePrice4k() { delete(m.clearedFields, group.FieldImagePrice4k) } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (m *GroupMutation) SetSoraImagePrice360(f float64) { + m.sora_image_price_360 = &f + m.addsora_image_price_360 = nil +} + +// SoraImagePrice360 returns the value of the "sora_image_price_360" field in the mutation. +func (m *GroupMutation) SoraImagePrice360() (r float64, exists bool) { + v := m.sora_image_price_360 + if v == nil { + return + } + return *v, true +} + +// OldSoraImagePrice360 returns the old "sora_image_price_360" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSoraImagePrice360(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraImagePrice360 is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraImagePrice360 requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraImagePrice360: %w", err) + } + return oldValue.SoraImagePrice360, nil +} + +// AddSoraImagePrice360 adds f to the "sora_image_price_360" field. +func (m *GroupMutation) AddSoraImagePrice360(f float64) { + if m.addsora_image_price_360 != nil { + *m.addsora_image_price_360 += f + } else { + m.addsora_image_price_360 = &f + } +} + +// AddedSoraImagePrice360 returns the value that was added to the "sora_image_price_360" field in this mutation. 
+func (m *GroupMutation) AddedSoraImagePrice360() (r float64, exists bool) { + v := m.addsora_image_price_360 + if v == nil { + return + } + return *v, true +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (m *GroupMutation) ClearSoraImagePrice360() { + m.sora_image_price_360 = nil + m.addsora_image_price_360 = nil + m.clearedFields[group.FieldSoraImagePrice360] = struct{}{} +} + +// SoraImagePrice360Cleared returns if the "sora_image_price_360" field was cleared in this mutation. +func (m *GroupMutation) SoraImagePrice360Cleared() bool { + _, ok := m.clearedFields[group.FieldSoraImagePrice360] + return ok +} + +// ResetSoraImagePrice360 resets all changes to the "sora_image_price_360" field. +func (m *GroupMutation) ResetSoraImagePrice360() { + m.sora_image_price_360 = nil + m.addsora_image_price_360 = nil + delete(m.clearedFields, group.FieldSoraImagePrice360) +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (m *GroupMutation) SetSoraImagePrice540(f float64) { + m.sora_image_price_540 = &f + m.addsora_image_price_540 = nil +} + +// SoraImagePrice540 returns the value of the "sora_image_price_540" field in the mutation. +func (m *GroupMutation) SoraImagePrice540() (r float64, exists bool) { + v := m.sora_image_price_540 + if v == nil { + return + } + return *v, true +} + +// OldSoraImagePrice540 returns the old "sora_image_price_540" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *GroupMutation) OldSoraImagePrice540(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraImagePrice540 is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraImagePrice540 requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraImagePrice540: %w", err) + } + return oldValue.SoraImagePrice540, nil +} + +// AddSoraImagePrice540 adds f to the "sora_image_price_540" field. +func (m *GroupMutation) AddSoraImagePrice540(f float64) { + if m.addsora_image_price_540 != nil { + *m.addsora_image_price_540 += f + } else { + m.addsora_image_price_540 = &f + } +} + +// AddedSoraImagePrice540 returns the value that was added to the "sora_image_price_540" field in this mutation. +func (m *GroupMutation) AddedSoraImagePrice540() (r float64, exists bool) { + v := m.addsora_image_price_540 + if v == nil { + return + } + return *v, true +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (m *GroupMutation) ClearSoraImagePrice540() { + m.sora_image_price_540 = nil + m.addsora_image_price_540 = nil + m.clearedFields[group.FieldSoraImagePrice540] = struct{}{} +} + +// SoraImagePrice540Cleared returns if the "sora_image_price_540" field was cleared in this mutation. +func (m *GroupMutation) SoraImagePrice540Cleared() bool { + _, ok := m.clearedFields[group.FieldSoraImagePrice540] + return ok +} + +// ResetSoraImagePrice540 resets all changes to the "sora_image_price_540" field. +func (m *GroupMutation) ResetSoraImagePrice540() { + m.sora_image_price_540 = nil + m.addsora_image_price_540 = nil + delete(m.clearedFields, group.FieldSoraImagePrice540) +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. 
+func (m *GroupMutation) SetSoraVideoPricePerRequest(f float64) { + m.sora_video_price_per_request = &f + m.addsora_video_price_per_request = nil +} + +// SoraVideoPricePerRequest returns the value of the "sora_video_price_per_request" field in the mutation. +func (m *GroupMutation) SoraVideoPricePerRequest() (r float64, exists bool) { + v := m.sora_video_price_per_request + if v == nil { + return + } + return *v, true +} + +// OldSoraVideoPricePerRequest returns the old "sora_video_price_per_request" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSoraVideoPricePerRequest(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraVideoPricePerRequest is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraVideoPricePerRequest requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraVideoPricePerRequest: %w", err) + } + return oldValue.SoraVideoPricePerRequest, nil +} + +// AddSoraVideoPricePerRequest adds f to the "sora_video_price_per_request" field. +func (m *GroupMutation) AddSoraVideoPricePerRequest(f float64) { + if m.addsora_video_price_per_request != nil { + *m.addsora_video_price_per_request += f + } else { + m.addsora_video_price_per_request = &f + } +} + +// AddedSoraVideoPricePerRequest returns the value that was added to the "sora_video_price_per_request" field in this mutation. +func (m *GroupMutation) AddedSoraVideoPricePerRequest() (r float64, exists bool) { + v := m.addsora_video_price_per_request + if v == nil { + return + } + return *v, true +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. 
+func (m *GroupMutation) ClearSoraVideoPricePerRequest() { + m.sora_video_price_per_request = nil + m.addsora_video_price_per_request = nil + m.clearedFields[group.FieldSoraVideoPricePerRequest] = struct{}{} +} + +// SoraVideoPricePerRequestCleared returns if the "sora_video_price_per_request" field was cleared in this mutation. +func (m *GroupMutation) SoraVideoPricePerRequestCleared() bool { + _, ok := m.clearedFields[group.FieldSoraVideoPricePerRequest] + return ok +} + +// ResetSoraVideoPricePerRequest resets all changes to the "sora_video_price_per_request" field. +func (m *GroupMutation) ResetSoraVideoPricePerRequest() { + m.sora_video_price_per_request = nil + m.addsora_video_price_per_request = nil + delete(m.clearedFields, group.FieldSoraVideoPricePerRequest) +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (m *GroupMutation) SetSoraVideoPricePerRequestHd(f float64) { + m.sora_video_price_per_request_hd = &f + m.addsora_video_price_per_request_hd = nil +} + +// SoraVideoPricePerRequestHd returns the value of the "sora_video_price_per_request_hd" field in the mutation. +func (m *GroupMutation) SoraVideoPricePerRequestHd() (r float64, exists bool) { + v := m.sora_video_price_per_request_hd + if v == nil { + return + } + return *v, true +} + +// OldSoraVideoPricePerRequestHd returns the old "sora_video_price_per_request_hd" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *GroupMutation) OldSoraVideoPricePerRequestHd(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraVideoPricePerRequestHd is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraVideoPricePerRequestHd requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraVideoPricePerRequestHd: %w", err) + } + return oldValue.SoraVideoPricePerRequestHd, nil +} + +// AddSoraVideoPricePerRequestHd adds f to the "sora_video_price_per_request_hd" field. +func (m *GroupMutation) AddSoraVideoPricePerRequestHd(f float64) { + if m.addsora_video_price_per_request_hd != nil { + *m.addsora_video_price_per_request_hd += f + } else { + m.addsora_video_price_per_request_hd = &f + } +} + +// AddedSoraVideoPricePerRequestHd returns the value that was added to the "sora_video_price_per_request_hd" field in this mutation. +func (m *GroupMutation) AddedSoraVideoPricePerRequestHd() (r float64, exists bool) { + v := m.addsora_video_price_per_request_hd + if v == nil { + return + } + return *v, true +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (m *GroupMutation) ClearSoraVideoPricePerRequestHd() { + m.sora_video_price_per_request_hd = nil + m.addsora_video_price_per_request_hd = nil + m.clearedFields[group.FieldSoraVideoPricePerRequestHd] = struct{}{} +} + +// SoraVideoPricePerRequestHdCleared returns if the "sora_video_price_per_request_hd" field was cleared in this mutation. +func (m *GroupMutation) SoraVideoPricePerRequestHdCleared() bool { + _, ok := m.clearedFields[group.FieldSoraVideoPricePerRequestHd] + return ok +} + +// ResetSoraVideoPricePerRequestHd resets all changes to the "sora_video_price_per_request_hd" field. 
+func (m *GroupMutation) ResetSoraVideoPricePerRequestHd() { + m.sora_video_price_per_request_hd = nil + m.addsora_video_price_per_request_hd = nil + delete(m.clearedFields, group.FieldSoraVideoPricePerRequestHd) +} + +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (m *GroupMutation) SetSoraStorageQuotaBytes(i int64) { + m.sora_storage_quota_bytes = &i + m.addsora_storage_quota_bytes = nil +} + +// SoraStorageQuotaBytes returns the value of the "sora_storage_quota_bytes" field in the mutation. +func (m *GroupMutation) SoraStorageQuotaBytes() (r int64, exists bool) { + v := m.sora_storage_quota_bytes + if v == nil { + return + } + return *v, true +} + +// OldSoraStorageQuotaBytes returns the old "sora_storage_quota_bytes" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSoraStorageQuotaBytes(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraStorageQuotaBytes is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraStorageQuotaBytes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraStorageQuotaBytes: %w", err) + } + return oldValue.SoraStorageQuotaBytes, nil +} + +// AddSoraStorageQuotaBytes adds i to the "sora_storage_quota_bytes" field. +func (m *GroupMutation) AddSoraStorageQuotaBytes(i int64) { + if m.addsora_storage_quota_bytes != nil { + *m.addsora_storage_quota_bytes += i + } else { + m.addsora_storage_quota_bytes = &i + } +} + +// AddedSoraStorageQuotaBytes returns the value that was added to the "sora_storage_quota_bytes" field in this mutation. 
+func (m *GroupMutation) AddedSoraStorageQuotaBytes() (r int64, exists bool) { + v := m.addsora_storage_quota_bytes + if v == nil { + return + } + return *v, true +} + +// ResetSoraStorageQuotaBytes resets all changes to the "sora_storage_quota_bytes" field. +func (m *GroupMutation) ResetSoraStorageQuotaBytes() { + m.sora_storage_quota_bytes = nil + m.addsora_storage_quota_bytes = nil +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (m *GroupMutation) SetClaudeCodeOnly(b bool) { m.claude_code_only = &b @@ -8881,7 +10191,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 25) + fields := make([]string, 0, 30) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -8933,6 +10243,21 @@ func (m *GroupMutation) Fields() []string { if m.image_price_4k != nil { fields = append(fields, group.FieldImagePrice4k) } + if m.sora_image_price_360 != nil { + fields = append(fields, group.FieldSoraImagePrice360) + } + if m.sora_image_price_540 != nil { + fields = append(fields, group.FieldSoraImagePrice540) + } + if m.sora_video_price_per_request != nil { + fields = append(fields, group.FieldSoraVideoPricePerRequest) + } + if m.sora_video_price_per_request_hd != nil { + fields = append(fields, group.FieldSoraVideoPricePerRequestHd) + } + if m.sora_storage_quota_bytes != nil { + fields = append(fields, group.FieldSoraStorageQuotaBytes) + } if m.claude_code_only != nil { fields = append(fields, group.FieldClaudeCodeOnly) } @@ -8999,6 +10324,16 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.ImagePrice2k() case group.FieldImagePrice4k: return m.ImagePrice4k() + case group.FieldSoraImagePrice360: + return m.SoraImagePrice360() + case group.FieldSoraImagePrice540: + return m.SoraImagePrice540() + case group.FieldSoraVideoPricePerRequest: + return 
m.SoraVideoPricePerRequest() + case group.FieldSoraVideoPricePerRequestHd: + return m.SoraVideoPricePerRequestHd() + case group.FieldSoraStorageQuotaBytes: + return m.SoraStorageQuotaBytes() case group.FieldClaudeCodeOnly: return m.ClaudeCodeOnly() case group.FieldFallbackGroupID: @@ -9058,6 +10393,16 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldImagePrice2k(ctx) case group.FieldImagePrice4k: return m.OldImagePrice4k(ctx) + case group.FieldSoraImagePrice360: + return m.OldSoraImagePrice360(ctx) + case group.FieldSoraImagePrice540: + return m.OldSoraImagePrice540(ctx) + case group.FieldSoraVideoPricePerRequest: + return m.OldSoraVideoPricePerRequest(ctx) + case group.FieldSoraVideoPricePerRequestHd: + return m.OldSoraVideoPricePerRequestHd(ctx) + case group.FieldSoraStorageQuotaBytes: + return m.OldSoraStorageQuotaBytes(ctx) case group.FieldClaudeCodeOnly: return m.OldClaudeCodeOnly(ctx) case group.FieldFallbackGroupID: @@ -9202,6 +10547,41 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetImagePrice4k(v) return nil + case group.FieldSoraImagePrice360: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraImagePrice360(v) + return nil + case group.FieldSoraImagePrice540: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraImagePrice540(v) + return nil + case group.FieldSoraVideoPricePerRequest: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraVideoPricePerRequest(v) + return nil + case group.FieldSoraVideoPricePerRequestHd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraVideoPricePerRequestHd(v) + return nil + case group.FieldSoraStorageQuotaBytes: + v, ok := value.(int64) + if !ok { + return 
fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraStorageQuotaBytes(v) + return nil case group.FieldClaudeCodeOnly: v, ok := value.(bool) if !ok { @@ -9290,6 +10670,21 @@ func (m *GroupMutation) AddedFields() []string { if m.addimage_price_4k != nil { fields = append(fields, group.FieldImagePrice4k) } + if m.addsora_image_price_360 != nil { + fields = append(fields, group.FieldSoraImagePrice360) + } + if m.addsora_image_price_540 != nil { + fields = append(fields, group.FieldSoraImagePrice540) + } + if m.addsora_video_price_per_request != nil { + fields = append(fields, group.FieldSoraVideoPricePerRequest) + } + if m.addsora_video_price_per_request_hd != nil { + fields = append(fields, group.FieldSoraVideoPricePerRequestHd) + } + if m.addsora_storage_quota_bytes != nil { + fields = append(fields, group.FieldSoraStorageQuotaBytes) + } if m.addfallback_group_id != nil { fields = append(fields, group.FieldFallbackGroupID) } @@ -9323,6 +10718,16 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { return m.AddedImagePrice2k() case group.FieldImagePrice4k: return m.AddedImagePrice4k() + case group.FieldSoraImagePrice360: + return m.AddedSoraImagePrice360() + case group.FieldSoraImagePrice540: + return m.AddedSoraImagePrice540() + case group.FieldSoraVideoPricePerRequest: + return m.AddedSoraVideoPricePerRequest() + case group.FieldSoraVideoPricePerRequestHd: + return m.AddedSoraVideoPricePerRequestHd() + case group.FieldSoraStorageQuotaBytes: + return m.AddedSoraStorageQuotaBytes() case group.FieldFallbackGroupID: return m.AddedFallbackGroupID() case group.FieldFallbackGroupIDOnInvalidRequest: @@ -9394,6 +10799,41 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error { } m.AddImagePrice4k(v) return nil + case group.FieldSoraImagePrice360: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraImagePrice360(v) + return nil + case 
group.FieldSoraImagePrice540: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraImagePrice540(v) + return nil + case group.FieldSoraVideoPricePerRequest: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraVideoPricePerRequest(v) + return nil + case group.FieldSoraVideoPricePerRequestHd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraVideoPricePerRequestHd(v) + return nil + case group.FieldSoraStorageQuotaBytes: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraStorageQuotaBytes(v) + return nil case group.FieldFallbackGroupID: v, ok := value.(int64) if !ok { @@ -9447,6 +10887,18 @@ func (m *GroupMutation) ClearedFields() []string { if m.FieldCleared(group.FieldImagePrice4k) { fields = append(fields, group.FieldImagePrice4k) } + if m.FieldCleared(group.FieldSoraImagePrice360) { + fields = append(fields, group.FieldSoraImagePrice360) + } + if m.FieldCleared(group.FieldSoraImagePrice540) { + fields = append(fields, group.FieldSoraImagePrice540) + } + if m.FieldCleared(group.FieldSoraVideoPricePerRequest) { + fields = append(fields, group.FieldSoraVideoPricePerRequest) + } + if m.FieldCleared(group.FieldSoraVideoPricePerRequestHd) { + fields = append(fields, group.FieldSoraVideoPricePerRequestHd) + } if m.FieldCleared(group.FieldFallbackGroupID) { fields = append(fields, group.FieldFallbackGroupID) } @@ -9494,6 +10946,18 @@ func (m *GroupMutation) ClearField(name string) error { case group.FieldImagePrice4k: m.ClearImagePrice4k() return nil + case group.FieldSoraImagePrice360: + m.ClearSoraImagePrice360() + return nil + case group.FieldSoraImagePrice540: + m.ClearSoraImagePrice540() + return nil + case group.FieldSoraVideoPricePerRequest: + m.ClearSoraVideoPricePerRequest() + return nil + 
case group.FieldSoraVideoPricePerRequestHd: + m.ClearSoraVideoPricePerRequestHd() + return nil case group.FieldFallbackGroupID: m.ClearFallbackGroupID() return nil @@ -9562,6 +11026,21 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldImagePrice4k: m.ResetImagePrice4k() return nil + case group.FieldSoraImagePrice360: + m.ResetSoraImagePrice360() + return nil + case group.FieldSoraImagePrice540: + m.ResetSoraImagePrice540() + return nil + case group.FieldSoraVideoPricePerRequest: + m.ResetSoraVideoPricePerRequest() + return nil + case group.FieldSoraVideoPricePerRequestHd: + m.ResetSoraVideoPricePerRequestHd() + return nil + case group.FieldSoraStorageQuotaBytes: + m.ResetSoraStorageQuotaBytes() + return nil case group.FieldClaudeCodeOnly: m.ResetClaudeCodeOnly() return nil @@ -9804,6 +11283,988 @@ func (m *GroupMutation) ResetEdge(name string) error { return fmt.Errorf("unknown Group edge %s", name) } +// IdempotencyRecordMutation represents an operation that mutates the IdempotencyRecord nodes in the graph. +type IdempotencyRecordMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + scope *string + idempotency_key_hash *string + request_fingerprint *string + status *string + response_status *int + addresponse_status *int + response_body *string + error_reason *string + locked_until *time.Time + expires_at *time.Time + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*IdempotencyRecord, error) + predicates []predicate.IdempotencyRecord +} + +var _ ent.Mutation = (*IdempotencyRecordMutation)(nil) + +// idempotencyrecordOption allows management of the mutation configuration using functional options. +type idempotencyrecordOption func(*IdempotencyRecordMutation) + +// newIdempotencyRecordMutation creates new mutation for the IdempotencyRecord entity. 
+func newIdempotencyRecordMutation(c config, op Op, opts ...idempotencyrecordOption) *IdempotencyRecordMutation { + m := &IdempotencyRecordMutation{ + config: c, + op: op, + typ: TypeIdempotencyRecord, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withIdempotencyRecordID sets the ID field of the mutation. +func withIdempotencyRecordID(id int64) idempotencyrecordOption { + return func(m *IdempotencyRecordMutation) { + var ( + err error + once sync.Once + value *IdempotencyRecord + ) + m.oldValue = func(ctx context.Context) (*IdempotencyRecord, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().IdempotencyRecord.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withIdempotencyRecord sets the old IdempotencyRecord of the mutation. +func withIdempotencyRecord(node *IdempotencyRecord) idempotencyrecordOption { + return func(m *IdempotencyRecordMutation) { + m.oldValue = func(context.Context) (*IdempotencyRecord, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m IdempotencyRecordMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m IdempotencyRecordMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. 
+func (m *IdempotencyRecordMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *IdempotencyRecordMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().IdempotencyRecord.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *IdempotencyRecordMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *IdempotencyRecordMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *IdempotencyRecordMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *IdempotencyRecordMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *IdempotencyRecordMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *IdempotencyRecordMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *IdempotencyRecordMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetScope sets the "scope" field. 
+func (m *IdempotencyRecordMutation) SetScope(s string) { + m.scope = &s +} + +// Scope returns the value of the "scope" field in the mutation. +func (m *IdempotencyRecordMutation) Scope() (r string, exists bool) { + v := m.scope + if v == nil { + return + } + return *v, true +} + +// OldScope returns the old "scope" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldScope(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldScope is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldScope requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldScope: %w", err) + } + return oldValue.Scope, nil +} + +// ResetScope resets all changes to the "scope" field. +func (m *IdempotencyRecordMutation) ResetScope() { + m.scope = nil +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. +func (m *IdempotencyRecordMutation) SetIdempotencyKeyHash(s string) { + m.idempotency_key_hash = &s +} + +// IdempotencyKeyHash returns the value of the "idempotency_key_hash" field in the mutation. +func (m *IdempotencyRecordMutation) IdempotencyKeyHash() (r string, exists bool) { + v := m.idempotency_key_hash + if v == nil { + return + } + return *v, true +} + +// OldIdempotencyKeyHash returns the old "idempotency_key_hash" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *IdempotencyRecordMutation) OldIdempotencyKeyHash(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIdempotencyKeyHash is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIdempotencyKeyHash requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIdempotencyKeyHash: %w", err) + } + return oldValue.IdempotencyKeyHash, nil +} + +// ResetIdempotencyKeyHash resets all changes to the "idempotency_key_hash" field. +func (m *IdempotencyRecordMutation) ResetIdempotencyKeyHash() { + m.idempotency_key_hash = nil +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (m *IdempotencyRecordMutation) SetRequestFingerprint(s string) { + m.request_fingerprint = &s +} + +// RequestFingerprint returns the value of the "request_fingerprint" field in the mutation. +func (m *IdempotencyRecordMutation) RequestFingerprint() (r string, exists bool) { + v := m.request_fingerprint + if v == nil { + return + } + return *v, true +} + +// OldRequestFingerprint returns the old "request_fingerprint" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *IdempotencyRecordMutation) OldRequestFingerprint(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRequestFingerprint is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRequestFingerprint requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRequestFingerprint: %w", err) + } + return oldValue.RequestFingerprint, nil +} + +// ResetRequestFingerprint resets all changes to the "request_fingerprint" field. +func (m *IdempotencyRecordMutation) ResetRequestFingerprint() { + m.request_fingerprint = nil +} + +// SetStatus sets the "status" field. +func (m *IdempotencyRecordMutation) SetStatus(s string) { + m.status = &s +} + +// Status returns the value of the "status" field in the mutation. +func (m *IdempotencyRecordMutation) Status() (r string, exists bool) { + v := m.status + if v == nil { + return + } + return *v, true +} + +// OldStatus returns the old "status" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldStatus(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStatus: %w", err) + } + return oldValue.Status, nil +} + +// ResetStatus resets all changes to the "status" field. 
+func (m *IdempotencyRecordMutation) ResetStatus() { + m.status = nil +} + +// SetResponseStatus sets the "response_status" field. +func (m *IdempotencyRecordMutation) SetResponseStatus(i int) { + m.response_status = &i + m.addresponse_status = nil +} + +// ResponseStatus returns the value of the "response_status" field in the mutation. +func (m *IdempotencyRecordMutation) ResponseStatus() (r int, exists bool) { + v := m.response_status + if v == nil { + return + } + return *v, true +} + +// OldResponseStatus returns the old "response_status" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldResponseStatus(ctx context.Context) (v *int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldResponseStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldResponseStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldResponseStatus: %w", err) + } + return oldValue.ResponseStatus, nil +} + +// AddResponseStatus adds i to the "response_status" field. +func (m *IdempotencyRecordMutation) AddResponseStatus(i int) { + if m.addresponse_status != nil { + *m.addresponse_status += i + } else { + m.addresponse_status = &i + } +} + +// AddedResponseStatus returns the value that was added to the "response_status" field in this mutation. +func (m *IdempotencyRecordMutation) AddedResponseStatus() (r int, exists bool) { + v := m.addresponse_status + if v == nil { + return + } + return *v, true +} + +// ClearResponseStatus clears the value of the "response_status" field. 
+func (m *IdempotencyRecordMutation) ClearResponseStatus() { + m.response_status = nil + m.addresponse_status = nil + m.clearedFields[idempotencyrecord.FieldResponseStatus] = struct{}{} +} + +// ResponseStatusCleared returns if the "response_status" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) ResponseStatusCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldResponseStatus] + return ok +} + +// ResetResponseStatus resets all changes to the "response_status" field. +func (m *IdempotencyRecordMutation) ResetResponseStatus() { + m.response_status = nil + m.addresponse_status = nil + delete(m.clearedFields, idempotencyrecord.FieldResponseStatus) +} + +// SetResponseBody sets the "response_body" field. +func (m *IdempotencyRecordMutation) SetResponseBody(s string) { + m.response_body = &s +} + +// ResponseBody returns the value of the "response_body" field in the mutation. +func (m *IdempotencyRecordMutation) ResponseBody() (r string, exists bool) { + v := m.response_body + if v == nil { + return + } + return *v, true +} + +// OldResponseBody returns the old "response_body" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldResponseBody(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldResponseBody is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldResponseBody requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldResponseBody: %w", err) + } + return oldValue.ResponseBody, nil +} + +// ClearResponseBody clears the value of the "response_body" field. 
+func (m *IdempotencyRecordMutation) ClearResponseBody() { + m.response_body = nil + m.clearedFields[idempotencyrecord.FieldResponseBody] = struct{}{} +} + +// ResponseBodyCleared returns if the "response_body" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) ResponseBodyCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldResponseBody] + return ok +} + +// ResetResponseBody resets all changes to the "response_body" field. +func (m *IdempotencyRecordMutation) ResetResponseBody() { + m.response_body = nil + delete(m.clearedFields, idempotencyrecord.FieldResponseBody) +} + +// SetErrorReason sets the "error_reason" field. +func (m *IdempotencyRecordMutation) SetErrorReason(s string) { + m.error_reason = &s +} + +// ErrorReason returns the value of the "error_reason" field in the mutation. +func (m *IdempotencyRecordMutation) ErrorReason() (r string, exists bool) { + v := m.error_reason + if v == nil { + return + } + return *v, true +} + +// OldErrorReason returns the old "error_reason" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldErrorReason(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldErrorReason is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldErrorReason requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldErrorReason: %w", err) + } + return oldValue.ErrorReason, nil +} + +// ClearErrorReason clears the value of the "error_reason" field. 
+func (m *IdempotencyRecordMutation) ClearErrorReason() { + m.error_reason = nil + m.clearedFields[idempotencyrecord.FieldErrorReason] = struct{}{} +} + +// ErrorReasonCleared returns if the "error_reason" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) ErrorReasonCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldErrorReason] + return ok +} + +// ResetErrorReason resets all changes to the "error_reason" field. +func (m *IdempotencyRecordMutation) ResetErrorReason() { + m.error_reason = nil + delete(m.clearedFields, idempotencyrecord.FieldErrorReason) +} + +// SetLockedUntil sets the "locked_until" field. +func (m *IdempotencyRecordMutation) SetLockedUntil(t time.Time) { + m.locked_until = &t +} + +// LockedUntil returns the value of the "locked_until" field in the mutation. +func (m *IdempotencyRecordMutation) LockedUntil() (r time.Time, exists bool) { + v := m.locked_until + if v == nil { + return + } + return *v, true +} + +// OldLockedUntil returns the old "locked_until" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldLockedUntil(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLockedUntil is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLockedUntil requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLockedUntil: %w", err) + } + return oldValue.LockedUntil, nil +} + +// ClearLockedUntil clears the value of the "locked_until" field. 
+func (m *IdempotencyRecordMutation) ClearLockedUntil() { + m.locked_until = nil + m.clearedFields[idempotencyrecord.FieldLockedUntil] = struct{}{} +} + +// LockedUntilCleared returns if the "locked_until" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) LockedUntilCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldLockedUntil] + return ok +} + +// ResetLockedUntil resets all changes to the "locked_until" field. +func (m *IdempotencyRecordMutation) ResetLockedUntil() { + m.locked_until = nil + delete(m.clearedFields, idempotencyrecord.FieldLockedUntil) +} + +// SetExpiresAt sets the "expires_at" field. +func (m *IdempotencyRecordMutation) SetExpiresAt(t time.Time) { + m.expires_at = &t +} + +// ExpiresAt returns the value of the "expires_at" field in the mutation. +func (m *IdempotencyRecordMutation) ExpiresAt() (r time.Time, exists bool) { + v := m.expires_at + if v == nil { + return + } + return *v, true +} + +// OldExpiresAt returns the old "expires_at" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldExpiresAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) + } + return oldValue.ExpiresAt, nil +} + +// ResetExpiresAt resets all changes to the "expires_at" field. +func (m *IdempotencyRecordMutation) ResetExpiresAt() { + m.expires_at = nil +} + +// Where appends a list predicates to the IdempotencyRecordMutation builder. 
+func (m *IdempotencyRecordMutation) Where(ps ...predicate.IdempotencyRecord) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the IdempotencyRecordMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *IdempotencyRecordMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.IdempotencyRecord, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *IdempotencyRecordMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *IdempotencyRecordMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (IdempotencyRecord). +func (m *IdempotencyRecordMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). 
+func (m *IdempotencyRecordMutation) Fields() []string { + fields := make([]string, 0, 11) + if m.created_at != nil { + fields = append(fields, idempotencyrecord.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, idempotencyrecord.FieldUpdatedAt) + } + if m.scope != nil { + fields = append(fields, idempotencyrecord.FieldScope) + } + if m.idempotency_key_hash != nil { + fields = append(fields, idempotencyrecord.FieldIdempotencyKeyHash) + } + if m.request_fingerprint != nil { + fields = append(fields, idempotencyrecord.FieldRequestFingerprint) + } + if m.status != nil { + fields = append(fields, idempotencyrecord.FieldStatus) + } + if m.response_status != nil { + fields = append(fields, idempotencyrecord.FieldResponseStatus) + } + if m.response_body != nil { + fields = append(fields, idempotencyrecord.FieldResponseBody) + } + if m.error_reason != nil { + fields = append(fields, idempotencyrecord.FieldErrorReason) + } + if m.locked_until != nil { + fields = append(fields, idempotencyrecord.FieldLockedUntil) + } + if m.expires_at != nil { + fields = append(fields, idempotencyrecord.FieldExpiresAt) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. 
+func (m *IdempotencyRecordMutation) Field(name string) (ent.Value, bool) { + switch name { + case idempotencyrecord.FieldCreatedAt: + return m.CreatedAt() + case idempotencyrecord.FieldUpdatedAt: + return m.UpdatedAt() + case idempotencyrecord.FieldScope: + return m.Scope() + case idempotencyrecord.FieldIdempotencyKeyHash: + return m.IdempotencyKeyHash() + case idempotencyrecord.FieldRequestFingerprint: + return m.RequestFingerprint() + case idempotencyrecord.FieldStatus: + return m.Status() + case idempotencyrecord.FieldResponseStatus: + return m.ResponseStatus() + case idempotencyrecord.FieldResponseBody: + return m.ResponseBody() + case idempotencyrecord.FieldErrorReason: + return m.ErrorReason() + case idempotencyrecord.FieldLockedUntil: + return m.LockedUntil() + case idempotencyrecord.FieldExpiresAt: + return m.ExpiresAt() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. 
+func (m *IdempotencyRecordMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case idempotencyrecord.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case idempotencyrecord.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case idempotencyrecord.FieldScope: + return m.OldScope(ctx) + case idempotencyrecord.FieldIdempotencyKeyHash: + return m.OldIdempotencyKeyHash(ctx) + case idempotencyrecord.FieldRequestFingerprint: + return m.OldRequestFingerprint(ctx) + case idempotencyrecord.FieldStatus: + return m.OldStatus(ctx) + case idempotencyrecord.FieldResponseStatus: + return m.OldResponseStatus(ctx) + case idempotencyrecord.FieldResponseBody: + return m.OldResponseBody(ctx) + case idempotencyrecord.FieldErrorReason: + return m.OldErrorReason(ctx) + case idempotencyrecord.FieldLockedUntil: + return m.OldLockedUntil(ctx) + case idempotencyrecord.FieldExpiresAt: + return m.OldExpiresAt(ctx) + } + return nil, fmt.Errorf("unknown IdempotencyRecord field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *IdempotencyRecordMutation) SetField(name string, value ent.Value) error { + switch name { + case idempotencyrecord.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case idempotencyrecord.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case idempotencyrecord.FieldScope: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetScope(v) + return nil + case idempotencyrecord.FieldIdempotencyKeyHash: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIdempotencyKeyHash(v) + return nil + case idempotencyrecord.FieldRequestFingerprint: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRequestFingerprint(v) + return nil + case idempotencyrecord.FieldStatus: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStatus(v) + return nil + case idempotencyrecord.FieldResponseStatus: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetResponseStatus(v) + return nil + case idempotencyrecord.FieldResponseBody: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetResponseBody(v) + return nil + case idempotencyrecord.FieldErrorReason: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetErrorReason(v) + return nil + case idempotencyrecord.FieldLockedUntil: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLockedUntil(v) + return nil + case 
idempotencyrecord.FieldExpiresAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExpiresAt(v) + return nil + } + return fmt.Errorf("unknown IdempotencyRecord field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *IdempotencyRecordMutation) AddedFields() []string { + var fields []string + if m.addresponse_status != nil { + fields = append(fields, idempotencyrecord.FieldResponseStatus) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *IdempotencyRecordMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case idempotencyrecord.FieldResponseStatus: + return m.AddedResponseStatus() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *IdempotencyRecordMutation) AddField(name string, value ent.Value) error { + switch name { + case idempotencyrecord.FieldResponseStatus: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddResponseStatus(v) + return nil + } + return fmt.Errorf("unknown IdempotencyRecord numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. 
+func (m *IdempotencyRecordMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(idempotencyrecord.FieldResponseStatus) { + fields = append(fields, idempotencyrecord.FieldResponseStatus) + } + if m.FieldCleared(idempotencyrecord.FieldResponseBody) { + fields = append(fields, idempotencyrecord.FieldResponseBody) + } + if m.FieldCleared(idempotencyrecord.FieldErrorReason) { + fields = append(fields, idempotencyrecord.FieldErrorReason) + } + if m.FieldCleared(idempotencyrecord.FieldLockedUntil) { + fields = append(fields, idempotencyrecord.FieldLockedUntil) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *IdempotencyRecordMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *IdempotencyRecordMutation) ClearField(name string) error { + switch name { + case idempotencyrecord.FieldResponseStatus: + m.ClearResponseStatus() + return nil + case idempotencyrecord.FieldResponseBody: + m.ClearResponseBody() + return nil + case idempotencyrecord.FieldErrorReason: + m.ClearErrorReason() + return nil + case idempotencyrecord.FieldLockedUntil: + m.ClearLockedUntil() + return nil + } + return fmt.Errorf("unknown IdempotencyRecord nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. 
+func (m *IdempotencyRecordMutation) ResetField(name string) error { + switch name { + case idempotencyrecord.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case idempotencyrecord.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case idempotencyrecord.FieldScope: + m.ResetScope() + return nil + case idempotencyrecord.FieldIdempotencyKeyHash: + m.ResetIdempotencyKeyHash() + return nil + case idempotencyrecord.FieldRequestFingerprint: + m.ResetRequestFingerprint() + return nil + case idempotencyrecord.FieldStatus: + m.ResetStatus() + return nil + case idempotencyrecord.FieldResponseStatus: + m.ResetResponseStatus() + return nil + case idempotencyrecord.FieldResponseBody: + m.ResetResponseBody() + return nil + case idempotencyrecord.FieldErrorReason: + m.ResetErrorReason() + return nil + case idempotencyrecord.FieldLockedUntil: + m.ResetLockedUntil() + return nil + case idempotencyrecord.FieldExpiresAt: + m.ResetExpiresAt() + return nil + } + return fmt.Errorf("unknown IdempotencyRecord field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *IdempotencyRecordMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *IdempotencyRecordMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *IdempotencyRecordMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *IdempotencyRecordMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. 
+func (m *IdempotencyRecordMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *IdempotencyRecordMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *IdempotencyRecordMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown IdempotencyRecord unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *IdempotencyRecordMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown IdempotencyRecord edge %s", name) +} + // PromoCodeMutation represents an operation that mutates the PromoCode nodes in the graph. type PromoCodeMutation struct { config @@ -13496,6 +15957,494 @@ func (m *RedeemCodeMutation) ResetEdge(name string) error { return fmt.Errorf("unknown RedeemCode edge %s", name) } +// SecuritySecretMutation represents an operation that mutates the SecuritySecret nodes in the graph. +type SecuritySecretMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + key *string + value *string + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*SecuritySecret, error) + predicates []predicate.SecuritySecret +} + +var _ ent.Mutation = (*SecuritySecretMutation)(nil) + +// securitysecretOption allows management of the mutation configuration using functional options. +type securitysecretOption func(*SecuritySecretMutation) + +// newSecuritySecretMutation creates new mutation for the SecuritySecret entity. 
+func newSecuritySecretMutation(c config, op Op, opts ...securitysecretOption) *SecuritySecretMutation { + m := &SecuritySecretMutation{ + config: c, + op: op, + typ: TypeSecuritySecret, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withSecuritySecretID sets the ID field of the mutation. +func withSecuritySecretID(id int64) securitysecretOption { + return func(m *SecuritySecretMutation) { + var ( + err error + once sync.Once + value *SecuritySecret + ) + m.oldValue = func(ctx context.Context) (*SecuritySecret, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().SecuritySecret.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withSecuritySecret sets the old SecuritySecret of the mutation. +func withSecuritySecret(node *SecuritySecret) securitysecretOption { + return func(m *SecuritySecretMutation) { + m.oldValue = func(context.Context) (*SecuritySecret, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m SecuritySecretMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m SecuritySecretMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. 
+func (m *SecuritySecretMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *SecuritySecretMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().SecuritySecret.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *SecuritySecretMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *SecuritySecretMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the SecuritySecret entity. +// If the SecuritySecret object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *SecuritySecretMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *SecuritySecretMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *SecuritySecretMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *SecuritySecretMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the SecuritySecret entity. +// If the SecuritySecret object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SecuritySecretMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *SecuritySecretMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetKey sets the "key" field. 
+func (m *SecuritySecretMutation) SetKey(s string) { + m.key = &s +} + +// Key returns the value of the "key" field in the mutation. +func (m *SecuritySecretMutation) Key() (r string, exists bool) { + v := m.key + if v == nil { + return + } + return *v, true +} + +// OldKey returns the old "key" field's value of the SecuritySecret entity. +// If the SecuritySecret object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SecuritySecretMutation) OldKey(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldKey: %w", err) + } + return oldValue.Key, nil +} + +// ResetKey resets all changes to the "key" field. +func (m *SecuritySecretMutation) ResetKey() { + m.key = nil +} + +// SetValue sets the "value" field. +func (m *SecuritySecretMutation) SetValue(s string) { + m.value = &s +} + +// Value returns the value of the "value" field in the mutation. +func (m *SecuritySecretMutation) Value() (r string, exists bool) { + v := m.value + if v == nil { + return + } + return *v, true +} + +// OldValue returns the old "value" field's value of the SecuritySecret entity. +// If the SecuritySecret object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *SecuritySecretMutation) OldValue(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldValue is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldValue requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldValue: %w", err) + } + return oldValue.Value, nil +} + +// ResetValue resets all changes to the "value" field. +func (m *SecuritySecretMutation) ResetValue() { + m.value = nil +} + +// Where appends a list predicates to the SecuritySecretMutation builder. +func (m *SecuritySecretMutation) Where(ps ...predicate.SecuritySecret) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the SecuritySecretMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *SecuritySecretMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.SecuritySecret, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *SecuritySecretMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *SecuritySecretMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (SecuritySecret). +func (m *SecuritySecretMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). 
+func (m *SecuritySecretMutation) Fields() []string { + fields := make([]string, 0, 4) + if m.created_at != nil { + fields = append(fields, securitysecret.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, securitysecret.FieldUpdatedAt) + } + if m.key != nil { + fields = append(fields, securitysecret.FieldKey) + } + if m.value != nil { + fields = append(fields, securitysecret.FieldValue) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *SecuritySecretMutation) Field(name string) (ent.Value, bool) { + switch name { + case securitysecret.FieldCreatedAt: + return m.CreatedAt() + case securitysecret.FieldUpdatedAt: + return m.UpdatedAt() + case securitysecret.FieldKey: + return m.Key() + case securitysecret.FieldValue: + return m.Value() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *SecuritySecretMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case securitysecret.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case securitysecret.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case securitysecret.FieldKey: + return m.OldKey(ctx) + case securitysecret.FieldValue: + return m.OldValue(ctx) + } + return nil, fmt.Errorf("unknown SecuritySecret field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *SecuritySecretMutation) SetField(name string, value ent.Value) error { + switch name { + case securitysecret.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case securitysecret.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case securitysecret.FieldKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetKey(v) + return nil + case securitysecret.FieldValue: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetValue(v) + return nil + } + return fmt.Errorf("unknown SecuritySecret field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *SecuritySecretMutation) AddedFields() []string { + return nil +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *SecuritySecretMutation) AddedField(name string) (ent.Value, bool) { + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *SecuritySecretMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown SecuritySecret numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. 
+func (m *SecuritySecretMutation) ClearedFields() []string { + return nil +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *SecuritySecretMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *SecuritySecretMutation) ClearField(name string) error { + return fmt.Errorf("unknown SecuritySecret nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *SecuritySecretMutation) ResetField(name string) error { + switch name { + case securitysecret.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case securitysecret.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case securitysecret.FieldKey: + m.ResetKey() + return nil + case securitysecret.FieldValue: + m.ResetValue() + return nil + } + return fmt.Errorf("unknown SecuritySecret field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *SecuritySecretMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *SecuritySecretMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *SecuritySecretMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. 
+func (m *SecuritySecretMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *SecuritySecretMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *SecuritySecretMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *SecuritySecretMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown SecuritySecret unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *SecuritySecretMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown SecuritySecret edge %s", name) +} + // SettingMutation represents an operation that mutates the Setting nodes in the graph. type SettingMutation struct { config @@ -15061,6 +18010,8 @@ type UsageLogMutation struct { image_count *int addimage_count *int image_size *string + media_type *string + cache_ttl_overridden *bool created_at *time.Time clearedFields map[string]struct{} user *int64 @@ -16687,6 +19638,91 @@ func (m *UsageLogMutation) ResetImageSize() { delete(m.clearedFields, usagelog.FieldImageSize) } +// SetMediaType sets the "media_type" field. +func (m *UsageLogMutation) SetMediaType(s string) { + m.media_type = &s +} + +// MediaType returns the value of the "media_type" field in the mutation. +func (m *UsageLogMutation) MediaType() (r string, exists bool) { + v := m.media_type + if v == nil { + return + } + return *v, true +} + +// OldMediaType returns the old "media_type" field's value of the UsageLog entity. 
+// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldMediaType(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMediaType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMediaType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMediaType: %w", err) + } + return oldValue.MediaType, nil +} + +// ClearMediaType clears the value of the "media_type" field. +func (m *UsageLogMutation) ClearMediaType() { + m.media_type = nil + m.clearedFields[usagelog.FieldMediaType] = struct{}{} +} + +// MediaTypeCleared returns if the "media_type" field was cleared in this mutation. +func (m *UsageLogMutation) MediaTypeCleared() bool { + _, ok := m.clearedFields[usagelog.FieldMediaType] + return ok +} + +// ResetMediaType resets all changes to the "media_type" field. +func (m *UsageLogMutation) ResetMediaType() { + m.media_type = nil + delete(m.clearedFields, usagelog.FieldMediaType) +} + +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (m *UsageLogMutation) SetCacheTTLOverridden(b bool) { + m.cache_ttl_overridden = &b +} + +// CacheTTLOverridden returns the value of the "cache_ttl_overridden" field in the mutation. +func (m *UsageLogMutation) CacheTTLOverridden() (r bool, exists bool) { + v := m.cache_ttl_overridden + if v == nil { + return + } + return *v, true +} + +// OldCacheTTLOverridden returns the old "cache_ttl_overridden" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *UsageLogMutation) OldCacheTTLOverridden(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCacheTTLOverridden is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCacheTTLOverridden requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCacheTTLOverridden: %w", err) + } + return oldValue.CacheTTLOverridden, nil +} + +// ResetCacheTTLOverridden resets all changes to the "cache_ttl_overridden" field. +func (m *UsageLogMutation) ResetCacheTTLOverridden() { + m.cache_ttl_overridden = nil +} + // SetCreatedAt sets the "created_at" field. func (m *UsageLogMutation) SetCreatedAt(t time.Time) { m.created_at = &t @@ -16892,7 +19928,7 @@ func (m *UsageLogMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UsageLogMutation) Fields() []string { - fields := make([]string, 0, 30) + fields := make([]string, 0, 32) if m.user != nil { fields = append(fields, usagelog.FieldUserID) } @@ -16980,6 +20016,12 @@ func (m *UsageLogMutation) Fields() []string { if m.image_size != nil { fields = append(fields, usagelog.FieldImageSize) } + if m.media_type != nil { + fields = append(fields, usagelog.FieldMediaType) + } + if m.cache_ttl_overridden != nil { + fields = append(fields, usagelog.FieldCacheTTLOverridden) + } if m.created_at != nil { fields = append(fields, usagelog.FieldCreatedAt) } @@ -17049,6 +20091,10 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) { return m.ImageCount() case usagelog.FieldImageSize: return m.ImageSize() + case usagelog.FieldMediaType: + return m.MediaType() + case usagelog.FieldCacheTTLOverridden: + return m.CacheTTLOverridden() case usagelog.FieldCreatedAt: return m.CreatedAt() } @@ -17118,6 +20164,10 @@ func (m *UsageLogMutation) OldField(ctx context.Context, 
name string) (ent.Value return m.OldImageCount(ctx) case usagelog.FieldImageSize: return m.OldImageSize(ctx) + case usagelog.FieldMediaType: + return m.OldMediaType(ctx) + case usagelog.FieldCacheTTLOverridden: + return m.OldCacheTTLOverridden(ctx) case usagelog.FieldCreatedAt: return m.OldCreatedAt(ctx) } @@ -17332,6 +20382,20 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error { } m.SetImageSize(v) return nil + case usagelog.FieldMediaType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMediaType(v) + return nil + case usagelog.FieldCacheTTLOverridden: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCacheTTLOverridden(v) + return nil case usagelog.FieldCreatedAt: v, ok := value.(time.Time) if !ok { @@ -17612,6 +20676,9 @@ func (m *UsageLogMutation) ClearedFields() []string { if m.FieldCleared(usagelog.FieldImageSize) { fields = append(fields, usagelog.FieldImageSize) } + if m.FieldCleared(usagelog.FieldMediaType) { + fields = append(fields, usagelog.FieldMediaType) + } return fields } @@ -17650,6 +20717,9 @@ func (m *UsageLogMutation) ClearField(name string) error { case usagelog.FieldImageSize: m.ClearImageSize() return nil + case usagelog.FieldMediaType: + m.ClearMediaType() + return nil } return fmt.Errorf("unknown UsageLog nullable field %s", name) } @@ -17745,6 +20815,12 @@ func (m *UsageLogMutation) ResetField(name string) error { case usagelog.FieldImageSize: m.ResetImageSize() return nil + case usagelog.FieldMediaType: + m.ResetMediaType() + return nil + case usagelog.FieldCacheTTLOverridden: + m.ResetCacheTTLOverridden() + return nil case usagelog.FieldCreatedAt: m.ResetCreatedAt() return nil @@ -17920,6 +20996,10 @@ type UserMutation struct { totp_secret_encrypted *string totp_enabled *bool totp_enabled_at *time.Time + sora_storage_quota_bytes *int64 + addsora_storage_quota_bytes *int64 + 
sora_storage_used_bytes *int64 + addsora_storage_used_bytes *int64 clearedFields map[string]struct{} api_keys map[int64]struct{} removedapi_keys map[int64]struct{} @@ -18634,6 +21714,118 @@ func (m *UserMutation) ResetTotpEnabledAt() { delete(m.clearedFields, user.FieldTotpEnabledAt) } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (m *UserMutation) SetSoraStorageQuotaBytes(i int64) { + m.sora_storage_quota_bytes = &i + m.addsora_storage_quota_bytes = nil +} + +// SoraStorageQuotaBytes returns the value of the "sora_storage_quota_bytes" field in the mutation. +func (m *UserMutation) SoraStorageQuotaBytes() (r int64, exists bool) { + v := m.sora_storage_quota_bytes + if v == nil { + return + } + return *v, true +} + +// OldSoraStorageQuotaBytes returns the old "sora_storage_quota_bytes" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldSoraStorageQuotaBytes(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraStorageQuotaBytes is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraStorageQuotaBytes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraStorageQuotaBytes: %w", err) + } + return oldValue.SoraStorageQuotaBytes, nil +} + +// AddSoraStorageQuotaBytes adds i to the "sora_storage_quota_bytes" field. +func (m *UserMutation) AddSoraStorageQuotaBytes(i int64) { + if m.addsora_storage_quota_bytes != nil { + *m.addsora_storage_quota_bytes += i + } else { + m.addsora_storage_quota_bytes = &i + } +} + +// AddedSoraStorageQuotaBytes returns the value that was added to the "sora_storage_quota_bytes" field in this mutation. 
+func (m *UserMutation) AddedSoraStorageQuotaBytes() (r int64, exists bool) { + v := m.addsora_storage_quota_bytes + if v == nil { + return + } + return *v, true +} + +// ResetSoraStorageQuotaBytes resets all changes to the "sora_storage_quota_bytes" field. +func (m *UserMutation) ResetSoraStorageQuotaBytes() { + m.sora_storage_quota_bytes = nil + m.addsora_storage_quota_bytes = nil +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (m *UserMutation) SetSoraStorageUsedBytes(i int64) { + m.sora_storage_used_bytes = &i + m.addsora_storage_used_bytes = nil +} + +// SoraStorageUsedBytes returns the value of the "sora_storage_used_bytes" field in the mutation. +func (m *UserMutation) SoraStorageUsedBytes() (r int64, exists bool) { + v := m.sora_storage_used_bytes + if v == nil { + return + } + return *v, true +} + +// OldSoraStorageUsedBytes returns the old "sora_storage_used_bytes" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldSoraStorageUsedBytes(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraStorageUsedBytes is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraStorageUsedBytes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraStorageUsedBytes: %w", err) + } + return oldValue.SoraStorageUsedBytes, nil +} + +// AddSoraStorageUsedBytes adds i to the "sora_storage_used_bytes" field. 
+func (m *UserMutation) AddSoraStorageUsedBytes(i int64) { + if m.addsora_storage_used_bytes != nil { + *m.addsora_storage_used_bytes += i + } else { + m.addsora_storage_used_bytes = &i + } +} + +// AddedSoraStorageUsedBytes returns the value that was added to the "sora_storage_used_bytes" field in this mutation. +func (m *UserMutation) AddedSoraStorageUsedBytes() (r int64, exists bool) { + v := m.addsora_storage_used_bytes + if v == nil { + return + } + return *v, true +} + +// ResetSoraStorageUsedBytes resets all changes to the "sora_storage_used_bytes" field. +func (m *UserMutation) ResetSoraStorageUsedBytes() { + m.sora_storage_used_bytes = nil + m.addsora_storage_used_bytes = nil +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *UserMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -19154,7 +22346,7 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 14) + fields := make([]string, 0, 16) if m.created_at != nil { fields = append(fields, user.FieldCreatedAt) } @@ -19197,6 +22389,12 @@ func (m *UserMutation) Fields() []string { if m.totp_enabled_at != nil { fields = append(fields, user.FieldTotpEnabledAt) } + if m.sora_storage_quota_bytes != nil { + fields = append(fields, user.FieldSoraStorageQuotaBytes) + } + if m.sora_storage_used_bytes != nil { + fields = append(fields, user.FieldSoraStorageUsedBytes) + } return fields } @@ -19233,6 +22431,10 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { return m.TotpEnabled() case user.FieldTotpEnabledAt: return m.TotpEnabledAt() + case user.FieldSoraStorageQuotaBytes: + return m.SoraStorageQuotaBytes() + case user.FieldSoraStorageUsedBytes: + return m.SoraStorageUsedBytes() } return nil, false } @@ -19270,6 +22472,10 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er 
return m.OldTotpEnabled(ctx) case user.FieldTotpEnabledAt: return m.OldTotpEnabledAt(ctx) + case user.FieldSoraStorageQuotaBytes: + return m.OldSoraStorageQuotaBytes(ctx) + case user.FieldSoraStorageUsedBytes: + return m.OldSoraStorageUsedBytes(ctx) } return nil, fmt.Errorf("unknown User field %s", name) } @@ -19377,6 +22583,20 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { } m.SetTotpEnabledAt(v) return nil + case user.FieldSoraStorageQuotaBytes: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraStorageQuotaBytes(v) + return nil + case user.FieldSoraStorageUsedBytes: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraStorageUsedBytes(v) + return nil } return fmt.Errorf("unknown User field %s", name) } @@ -19391,6 +22611,12 @@ func (m *UserMutation) AddedFields() []string { if m.addconcurrency != nil { fields = append(fields, user.FieldConcurrency) } + if m.addsora_storage_quota_bytes != nil { + fields = append(fields, user.FieldSoraStorageQuotaBytes) + } + if m.addsora_storage_used_bytes != nil { + fields = append(fields, user.FieldSoraStorageUsedBytes) + } return fields } @@ -19403,6 +22629,10 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) { return m.AddedBalance() case user.FieldConcurrency: return m.AddedConcurrency() + case user.FieldSoraStorageQuotaBytes: + return m.AddedSoraStorageQuotaBytes() + case user.FieldSoraStorageUsedBytes: + return m.AddedSoraStorageUsedBytes() } return nil, false } @@ -19426,6 +22656,20 @@ func (m *UserMutation) AddField(name string, value ent.Value) error { } m.AddConcurrency(v) return nil + case user.FieldSoraStorageQuotaBytes: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraStorageQuotaBytes(v) + return nil + case user.FieldSoraStorageUsedBytes: + v, ok := value.(int64) + if !ok 
{ + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraStorageUsedBytes(v) + return nil } return fmt.Errorf("unknown User numeric field %s", name) } @@ -19516,6 +22760,12 @@ func (m *UserMutation) ResetField(name string) error { case user.FieldTotpEnabledAt: m.ResetTotpEnabledAt() return nil + case user.FieldSoraStorageQuotaBytes: + m.ResetSoraStorageQuotaBytes() + return nil + case user.FieldSoraStorageUsedBytes: + m.ResetSoraStorageUsedBytes() + return nil } return fmt.Errorf("unknown User field %s", name) } diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go index c12955ef..89d933fc 100644 --- a/backend/ent/predicate/predicate.go +++ b/backend/ent/predicate/predicate.go @@ -27,6 +27,9 @@ type ErrorPassthroughRule func(*sql.Selector) // Group is the predicate function for group builders. type Group func(*sql.Selector) +// IdempotencyRecord is the predicate function for idempotencyrecord builders. +type IdempotencyRecord func(*sql.Selector) + // PromoCode is the predicate function for promocode builders. type PromoCode func(*sql.Selector) @@ -39,6 +42,9 @@ type Proxy func(*sql.Selector) // RedeemCode is the predicate function for redeemcode builders. type RedeemCode func(*sql.Selector) +// SecuritySecret is the predicate function for securitysecret builders. +type SecuritySecret func(*sql.Selector) + // Setting is the predicate function for setting builders. 
type Setting func(*sql.Selector) diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 7713224c..2c7467f6 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -12,11 +12,13 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/schema" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" "github.com/Wei-Shaw/sub2api/ent/setting" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" @@ -93,13 +95,37 @@ func init() { // apikey.StatusValidator is a validator for the "status" field. It is called by the builders before save. apikey.StatusValidator = apikeyDescStatus.Validators[0].(func(string) error) // apikeyDescQuota is the schema descriptor for quota field. - apikeyDescQuota := apikeyFields[7].Descriptor() + apikeyDescQuota := apikeyFields[8].Descriptor() // apikey.DefaultQuota holds the default value on creation for the quota field. apikey.DefaultQuota = apikeyDescQuota.Default.(float64) // apikeyDescQuotaUsed is the schema descriptor for quota_used field. - apikeyDescQuotaUsed := apikeyFields[8].Descriptor() + apikeyDescQuotaUsed := apikeyFields[9].Descriptor() // apikey.DefaultQuotaUsed holds the default value on creation for the quota_used field. apikey.DefaultQuotaUsed = apikeyDescQuotaUsed.Default.(float64) + // apikeyDescRateLimit5h is the schema descriptor for rate_limit_5h field. + apikeyDescRateLimit5h := apikeyFields[11].Descriptor() + // apikey.DefaultRateLimit5h holds the default value on creation for the rate_limit_5h field. 
+ apikey.DefaultRateLimit5h = apikeyDescRateLimit5h.Default.(float64) + // apikeyDescRateLimit1d is the schema descriptor for rate_limit_1d field. + apikeyDescRateLimit1d := apikeyFields[12].Descriptor() + // apikey.DefaultRateLimit1d holds the default value on creation for the rate_limit_1d field. + apikey.DefaultRateLimit1d = apikeyDescRateLimit1d.Default.(float64) + // apikeyDescRateLimit7d is the schema descriptor for rate_limit_7d field. + apikeyDescRateLimit7d := apikeyFields[13].Descriptor() + // apikey.DefaultRateLimit7d holds the default value on creation for the rate_limit_7d field. + apikey.DefaultRateLimit7d = apikeyDescRateLimit7d.Default.(float64) + // apikeyDescUsage5h is the schema descriptor for usage_5h field. + apikeyDescUsage5h := apikeyFields[14].Descriptor() + // apikey.DefaultUsage5h holds the default value on creation for the usage_5h field. + apikey.DefaultUsage5h = apikeyDescUsage5h.Default.(float64) + // apikeyDescUsage1d is the schema descriptor for usage_1d field. + apikeyDescUsage1d := apikeyFields[15].Descriptor() + // apikey.DefaultUsage1d holds the default value on creation for the usage_1d field. + apikey.DefaultUsage1d = apikeyDescUsage1d.Default.(float64) + // apikeyDescUsage7d is the schema descriptor for usage_7d field. + apikeyDescUsage7d := apikeyFields[16].Descriptor() + // apikey.DefaultUsage7d holds the default value on creation for the usage_7d field. + apikey.DefaultUsage7d = apikeyDescUsage7d.Default.(float64) accountMixin := schema.Account{}.Mixin() accountMixinHooks1 := accountMixin[1].Hooks() account.Hooks[0] = accountMixinHooks1[0] @@ -208,7 +234,7 @@ func init() { // account.DefaultSchedulable holds the default value on creation for the schedulable field. account.DefaultSchedulable = accountDescSchedulable.Default.(bool) // accountDescSessionWindowStatus is the schema descriptor for session_window_status field. 
- accountDescSessionWindowStatus := accountFields[21].Descriptor() + accountDescSessionWindowStatus := accountFields[23].Descriptor() // account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save. account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error) accountgroupFields := schema.AccountGroup{}.Fields() @@ -397,26 +423,65 @@ func init() { groupDescDefaultValidityDays := groupFields[10].Descriptor() // group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field. group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int) + // groupDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field. + groupDescSoraStorageQuotaBytes := groupFields[18].Descriptor() + // group.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field. + group.DefaultSoraStorageQuotaBytes = groupDescSoraStorageQuotaBytes.Default.(int64) // groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field. - groupDescClaudeCodeOnly := groupFields[14].Descriptor() + groupDescClaudeCodeOnly := groupFields[19].Descriptor() // group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field. group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool) // groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field. - groupDescModelRoutingEnabled := groupFields[18].Descriptor() + groupDescModelRoutingEnabled := groupFields[23].Descriptor() // group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field. group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool) // groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field. 
- groupDescMcpXMLInject := groupFields[19].Descriptor() + groupDescMcpXMLInject := groupFields[24].Descriptor() // group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field. group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool) // groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field. - groupDescSupportedModelScopes := groupFields[20].Descriptor() + groupDescSupportedModelScopes := groupFields[25].Descriptor() // group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field. group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string) // groupDescSortOrder is the schema descriptor for sort_order field. - groupDescSortOrder := groupFields[21].Descriptor() + groupDescSortOrder := groupFields[26].Descriptor() // group.DefaultSortOrder holds the default value on creation for the sort_order field. group.DefaultSortOrder = groupDescSortOrder.Default.(int) + idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin() + idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields() + _ = idempotencyrecordMixinFields0 + idempotencyrecordFields := schema.IdempotencyRecord{}.Fields() + _ = idempotencyrecordFields + // idempotencyrecordDescCreatedAt is the schema descriptor for created_at field. + idempotencyrecordDescCreatedAt := idempotencyrecordMixinFields0[0].Descriptor() + // idempotencyrecord.DefaultCreatedAt holds the default value on creation for the created_at field. + idempotencyrecord.DefaultCreatedAt = idempotencyrecordDescCreatedAt.Default.(func() time.Time) + // idempotencyrecordDescUpdatedAt is the schema descriptor for updated_at field. + idempotencyrecordDescUpdatedAt := idempotencyrecordMixinFields0[1].Descriptor() + // idempotencyrecord.DefaultUpdatedAt holds the default value on creation for the updated_at field. 
+ idempotencyrecord.DefaultUpdatedAt = idempotencyrecordDescUpdatedAt.Default.(func() time.Time) + // idempotencyrecord.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + idempotencyrecord.UpdateDefaultUpdatedAt = idempotencyrecordDescUpdatedAt.UpdateDefault.(func() time.Time) + // idempotencyrecordDescScope is the schema descriptor for scope field. + idempotencyrecordDescScope := idempotencyrecordFields[0].Descriptor() + // idempotencyrecord.ScopeValidator is a validator for the "scope" field. It is called by the builders before save. + idempotencyrecord.ScopeValidator = idempotencyrecordDescScope.Validators[0].(func(string) error) + // idempotencyrecordDescIdempotencyKeyHash is the schema descriptor for idempotency_key_hash field. + idempotencyrecordDescIdempotencyKeyHash := idempotencyrecordFields[1].Descriptor() + // idempotencyrecord.IdempotencyKeyHashValidator is a validator for the "idempotency_key_hash" field. It is called by the builders before save. + idempotencyrecord.IdempotencyKeyHashValidator = idempotencyrecordDescIdempotencyKeyHash.Validators[0].(func(string) error) + // idempotencyrecordDescRequestFingerprint is the schema descriptor for request_fingerprint field. + idempotencyrecordDescRequestFingerprint := idempotencyrecordFields[2].Descriptor() + // idempotencyrecord.RequestFingerprintValidator is a validator for the "request_fingerprint" field. It is called by the builders before save. + idempotencyrecord.RequestFingerprintValidator = idempotencyrecordDescRequestFingerprint.Validators[0].(func(string) error) + // idempotencyrecordDescStatus is the schema descriptor for status field. + idempotencyrecordDescStatus := idempotencyrecordFields[3].Descriptor() + // idempotencyrecord.StatusValidator is a validator for the "status" field. It is called by the builders before save. 
+ idempotencyrecord.StatusValidator = idempotencyrecordDescStatus.Validators[0].(func(string) error) + // idempotencyrecordDescErrorReason is the schema descriptor for error_reason field. + idempotencyrecordDescErrorReason := idempotencyrecordFields[6].Descriptor() + // idempotencyrecord.ErrorReasonValidator is a validator for the "error_reason" field. It is called by the builders before save. + idempotencyrecord.ErrorReasonValidator = idempotencyrecordDescErrorReason.Validators[0].(func(string) error) promocodeFields := schema.PromoCode{}.Fields() _ = promocodeFields // promocodeDescCode is the schema descriptor for code field. @@ -602,6 +667,43 @@ func init() { redeemcodeDescValidityDays := redeemcodeFields[9].Descriptor() // redeemcode.DefaultValidityDays holds the default value on creation for the validity_days field. redeemcode.DefaultValidityDays = redeemcodeDescValidityDays.Default.(int) + securitysecretMixin := schema.SecuritySecret{}.Mixin() + securitysecretMixinFields0 := securitysecretMixin[0].Fields() + _ = securitysecretMixinFields0 + securitysecretFields := schema.SecuritySecret{}.Fields() + _ = securitysecretFields + // securitysecretDescCreatedAt is the schema descriptor for created_at field. + securitysecretDescCreatedAt := securitysecretMixinFields0[0].Descriptor() + // securitysecret.DefaultCreatedAt holds the default value on creation for the created_at field. + securitysecret.DefaultCreatedAt = securitysecretDescCreatedAt.Default.(func() time.Time) + // securitysecretDescUpdatedAt is the schema descriptor for updated_at field. + securitysecretDescUpdatedAt := securitysecretMixinFields0[1].Descriptor() + // securitysecret.DefaultUpdatedAt holds the default value on creation for the updated_at field. + securitysecret.DefaultUpdatedAt = securitysecretDescUpdatedAt.Default.(func() time.Time) + // securitysecret.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. 
+ securitysecret.UpdateDefaultUpdatedAt = securitysecretDescUpdatedAt.UpdateDefault.(func() time.Time) + // securitysecretDescKey is the schema descriptor for key field. + securitysecretDescKey := securitysecretFields[0].Descriptor() + // securitysecret.KeyValidator is a validator for the "key" field. It is called by the builders before save. + securitysecret.KeyValidator = func() func(string) error { + validators := securitysecretDescKey.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(key string) error { + for _, fn := range fns { + if err := fn(key); err != nil { + return err + } + } + return nil + } + }() + // securitysecretDescValue is the schema descriptor for value field. + securitysecretDescValue := securitysecretFields[1].Descriptor() + // securitysecret.ValueValidator is a validator for the "value" field. It is called by the builders before save. + securitysecret.ValueValidator = securitysecretDescValue.Validators[0].(func(string) error) settingFields := schema.Setting{}.Fields() _ = settingFields // settingDescKey is the schema descriptor for key field. @@ -779,8 +881,16 @@ func init() { usagelogDescImageSize := usagelogFields[28].Descriptor() // usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error) + // usagelogDescMediaType is the schema descriptor for media_type field. + usagelogDescMediaType := usagelogFields[29].Descriptor() + // usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. + usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error) + // usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field. 
+ usagelogDescCacheTTLOverridden := usagelogFields[30].Descriptor() + // usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field. + usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool) // usagelogDescCreatedAt is the schema descriptor for created_at field. - usagelogDescCreatedAt := usagelogFields[29].Descriptor() + usagelogDescCreatedAt := usagelogFields[31].Descriptor() // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) userMixin := schema.User{}.Mixin() @@ -872,6 +982,14 @@ func init() { userDescTotpEnabled := userFields[9].Descriptor() // user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field. user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool) + // userDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field. + userDescSoraStorageQuotaBytes := userFields[11].Descriptor() + // user.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field. + user.DefaultSoraStorageQuotaBytes = userDescSoraStorageQuotaBytes.Default.(int64) + // userDescSoraStorageUsedBytes is the schema descriptor for sora_storage_used_bytes field. + userDescSoraStorageUsedBytes := userFields[12].Descriptor() + // user.DefaultSoraStorageUsedBytes holds the default value on creation for the sora_storage_used_bytes field. + user.DefaultSoraStorageUsedBytes = userDescSoraStorageUsedBytes.Default.(int64) userallowedgroupFields := schema.UserAllowedGroup{}.Fields() _ = userallowedgroupFields // userallowedgroupDescCreatedAt is the schema descriptor for created_at field. 
diff --git a/backend/ent/schema/account.go b/backend/ent/schema/account.go index 1cfecc2d..443f9e09 100644 --- a/backend/ent/schema/account.go +++ b/backend/ent/schema/account.go @@ -164,6 +164,19 @@ func (Account) Fields() []ent.Field { Nillable(). SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + // temp_unschedulable_until: 临时不可调度状态解除时间 + // 当命中临时不可调度规则时设置,在此时间前调度器应跳过该账号 + field.Time("temp_unschedulable_until"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + + // temp_unschedulable_reason: 临时不可调度原因,便于排障审计 + field.String("temp_unschedulable_reason"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + // session_window_*: 会话窗口相关字段 // 用于管理某些需要会话时间窗口的 API(如 Claude Pro) field.Time("session_window_start"). @@ -213,6 +226,9 @@ func (Account) Indexes() []ent.Index { index.Fields("rate_limited_at"), // 筛选速率限制账户 index.Fields("rate_limit_reset_at"), // 筛选速率限制解除时间 index.Fields("overload_until"), // 筛选过载账户 - index.Fields("deleted_at"), // 软删除查询优化 + // 调度热路径复合索引(线上由 SQL 迁移创建部分索引,schema 仅用于模型可读性对齐) + index.Fields("platform", "priority"), + index.Fields("priority", "status"), + index.Fields("deleted_at"), // 软删除查询优化 } } diff --git a/backend/ent/schema/api_key.go b/backend/ent/schema/api_key.go index 26d52cb0..5db51270 100644 --- a/backend/ent/schema/api_key.go +++ b/backend/ent/schema/api_key.go @@ -47,6 +47,10 @@ func (APIKey) Fields() []ent.Field { field.String("status"). MaxLen(20). Default(domain.StatusActive), + field.Time("last_used_at"). + Optional(). + Nillable(). + Comment("Last usage time of this API key"), field.JSON("ip_whitelist", []string{}). Optional(). Comment("Allowed IPs/CIDRs, e.g. [\"192.168.1.100\", \"10.0.0.0/8\"]"), @@ -70,6 +74,47 @@ func (APIKey) Fields() []ent.Field { Optional(). Nillable(). 
Comment("Expiration time for this API key (null = never expires)"), + + // ========== Rate limit fields ========== + // Rate limit configuration (0 = unlimited) + field.Float("rate_limit_5h"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0). + Comment("Rate limit in USD per 5 hours (0 = unlimited)"), + field.Float("rate_limit_1d"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0). + Comment("Rate limit in USD per day (0 = unlimited)"), + field.Float("rate_limit_7d"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0). + Comment("Rate limit in USD per 7 days (0 = unlimited)"), + // Rate limit usage tracking + field.Float("usage_5h"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0). + Comment("Used amount in USD for the current 5h window"), + field.Float("usage_1d"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0). + Comment("Used amount in USD for the current 1d window"), + field.Float("usage_7d"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0). + Comment("Used amount in USD for the current 7d window"), + // Window start times + field.Time("window_5h_start"). + Optional(). + Nillable(). + Comment("Start time of the current 5h rate limit window"), + field.Time("window_1d_start"). + Optional(). + Nillable(). + Comment("Start time of the current 1d rate limit window"), + field.Time("window_7d_start"). + Optional(). + Nillable(). 
+ Comment("Start time of the current 7d rate limit window"), } } @@ -95,6 +140,7 @@ func (APIKey) Indexes() []ent.Index { index.Fields("group_id"), index.Fields("status"), index.Fields("deleted_at"), + index.Fields("last_used_at"), // Index for quota queries index.Fields("quota", "quota_used"), index.Fields("expires_at"), diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index c36ca770..3fcf8674 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -87,6 +87,28 @@ func (Group) Fields() []ent.Field { Nillable(). SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + // Sora 按次计费配置(阶段 1) + field.Float("sora_image_price_360"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + field.Float("sora_image_price_540"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + field.Float("sora_video_price_per_request"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + field.Float("sora_video_price_per_request_hd"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + + // Sora 存储配额 + field.Int64("sora_storage_quota_bytes"). + Default(0), + // Claude Code 客户端限制 (added by migration 029) field.Bool("claude_code_only"). Default(false). 
diff --git a/backend/ent/schema/idempotency_record.go b/backend/ent/schema/idempotency_record.go new file mode 100644 index 00000000..ed09ad65 --- /dev/null +++ b/backend/ent/schema/idempotency_record.go @@ -0,0 +1,50 @@ +package schema + +import ( + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// IdempotencyRecord 幂等请求记录表。 +type IdempotencyRecord struct { + ent.Schema +} + +func (IdempotencyRecord) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "idempotency_records"}, + } +} + +func (IdempotencyRecord) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (IdempotencyRecord) Fields() []ent.Field { + return []ent.Field{ + field.String("scope").MaxLen(128), + field.String("idempotency_key_hash").MaxLen(64), + field.String("request_fingerprint").MaxLen(64), + field.String("status").MaxLen(32), + field.Int("response_status").Optional().Nillable(), + field.String("response_body").Optional().Nillable(), + field.String("error_reason").MaxLen(128).Optional().Nillable(), + field.Time("locked_until").Optional().Nillable(), + field.Time("expires_at"), + } +} + +func (IdempotencyRecord) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("scope", "idempotency_key_hash").Unique(), + index.Fields("expires_at"), + index.Fields("status", "locked_until"), + } +} diff --git a/backend/ent/schema/security_secret.go b/backend/ent/schema/security_secret.go new file mode 100644 index 00000000..ffe6d348 --- /dev/null +++ b/backend/ent/schema/security_secret.go @@ -0,0 +1,42 @@ +package schema + +import ( + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/field" +) + +// SecuritySecret 存储系统级安全密钥(如 JWT 签名密钥、TOTP 加密密钥)。 +type SecuritySecret struct { 
+ ent.Schema +} + +func (SecuritySecret) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "security_secrets"}, + } +} + +func (SecuritySecret) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (SecuritySecret) Fields() []ent.Field { + return []ent.Field{ + field.String("key"). + MaxLen(100). + NotEmpty(). + Unique(), + field.String("value"). + NotEmpty(). + SchemaType(map[string]string{ + dialect.Postgres: "text", + }), + } +} diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go index fc7c7165..dcca1a0a 100644 --- a/backend/ent/schema/usage_log.go +++ b/backend/ent/schema/usage_log.go @@ -118,6 +118,15 @@ func (UsageLog) Fields() []ent.Field { MaxLen(10). Optional(). Nillable(), + // 媒体类型字段(sora 使用) + field.String("media_type"). + MaxLen(16). + Optional(). + Nillable(), + + // Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费) + field.Bool("cache_ttl_overridden"). + Default(false), // 时间戳(只有 created_at,日志不可修改) field.Time("created_at"). @@ -170,5 +179,7 @@ func (UsageLog) Indexes() []ent.Index { // 复合索引用于时间范围查询 index.Fields("user_id", "created_at"), index.Fields("api_key_id", "created_at"), + // 分组维度时间范围查询(线上由 SQL 迁移创建 group_id IS NOT NULL 的部分索引) + index.Fields("group_id", "created_at"), } } diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index d443ef45..0a3b5d9e 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -72,6 +72,12 @@ func (User) Fields() []ent.Field { field.Time("totp_enabled_at"). Optional(). Nillable(), + + // Sora 存储配额 + field.Int64("sora_storage_quota_bytes"). + Default(0), + field.Int64("sora_storage_used_bytes"). 
+ Default(0), } } diff --git a/backend/ent/schema/user_subscription.go b/backend/ent/schema/user_subscription.go index fa13612b..a81850b1 100644 --- a/backend/ent/schema/user_subscription.go +++ b/backend/ent/schema/user_subscription.go @@ -108,6 +108,8 @@ func (UserSubscription) Indexes() []ent.Index { index.Fields("group_id"), index.Fields("status"), index.Fields("expires_at"), + // 活跃订阅查询复合索引(线上由 SQL 迁移创建部分索引,schema 仅用于模型可读性对齐) + index.Fields("user_id", "status", "expires_at"), index.Fields("assigned_by"), // 唯一约束通过部分索引实现(WHERE deleted_at IS NULL),支持软删除后重新订阅 // 见迁移文件 016_soft_delete_partial_unique_indexes.sql diff --git a/backend/ent/securitysecret.go b/backend/ent/securitysecret.go new file mode 100644 index 00000000..e0e93c91 --- /dev/null +++ b/backend/ent/securitysecret.go @@ -0,0 +1,139 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" +) + +// SecuritySecret is the model entity for the SecuritySecret schema. +type SecuritySecret struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // Key holds the value of the "key" field. + Key string `json:"key,omitempty"` + // Value holds the value of the "value" field. + Value string `json:"value,omitempty"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. 
+func (*SecuritySecret) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case securitysecret.FieldID: + values[i] = new(sql.NullInt64) + case securitysecret.FieldKey, securitysecret.FieldValue: + values[i] = new(sql.NullString) + case securitysecret.FieldCreatedAt, securitysecret.FieldUpdatedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the SecuritySecret fields. +func (_m *SecuritySecret) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case securitysecret.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case securitysecret.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case securitysecret.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case securitysecret.FieldKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field key", values[i]) + } else if value.Valid { + _m.Key = value.String + } + case securitysecret.FieldValue: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field value", values[i]) + } else if value.Valid { + _m.Value = value.String + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// GetValue 
returns the ent.Value that was dynamically selected and assigned to the SecuritySecret. +// This includes values selected through modifiers, order, etc. +func (_m *SecuritySecret) GetValue(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// Update returns a builder for updating this SecuritySecret. +// Note that you need to call SecuritySecret.Unwrap() before calling this method if this SecuritySecret +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *SecuritySecret) Update() *SecuritySecretUpdateOne { + return NewSecuritySecretClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the SecuritySecret entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *SecuritySecret) Unwrap() *SecuritySecret { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: SecuritySecret is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *SecuritySecret) String() string { + var builder strings.Builder + builder.WriteString("SecuritySecret(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("key=") + builder.WriteString(_m.Key) + builder.WriteString(", ") + builder.WriteString("value=") + builder.WriteString(_m.Value) + builder.WriteByte(')') + return builder.String() +} + +// SecuritySecrets is a parsable slice of SecuritySecret. 
+type SecuritySecrets []*SecuritySecret diff --git a/backend/ent/securitysecret/securitysecret.go b/backend/ent/securitysecret/securitysecret.go new file mode 100644 index 00000000..4c5d9ef6 --- /dev/null +++ b/backend/ent/securitysecret/securitysecret.go @@ -0,0 +1,86 @@ +// Code generated by ent, DO NOT EDIT. + +package securitysecret + +import ( + "time" + + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the securitysecret type in the database. + Label = "security_secret" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldKey holds the string denoting the key field in the database. + FieldKey = "key" + // FieldValue holds the string denoting the value field in the database. + FieldValue = "value" + // Table holds the table name of the securitysecret in the database. + Table = "security_secrets" +) + +// Columns holds all SQL columns for securitysecret fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldKey, + FieldValue, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // KeyValidator is a validator for the "key" field. It is called by the builders before save. 
+ KeyValidator func(string) error + // ValueValidator is a validator for the "value" field. It is called by the builders before save. + ValueValidator func(string) error +) + +// OrderOption defines the ordering options for the SecuritySecret queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByKey orders the results by the key field. +func ByKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldKey, opts...).ToFunc() +} + +// ByValue orders the results by the value field. +func ByValue(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldValue, opts...).ToFunc() +} diff --git a/backend/ent/securitysecret/where.go b/backend/ent/securitysecret/where.go new file mode 100644 index 00000000..34f50752 --- /dev/null +++ b/backend/ent/securitysecret/where.go @@ -0,0 +1,300 @@ +// Code generated by ent, DO NOT EDIT. + +package securitysecret + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. 
+func IDNEQ(id int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// Key applies equality check predicate on the "key" field. It's identical to KeyEQ. +func Key(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldKey, v)) +} + +// Value applies equality check predicate on the "value" field. It's identical to ValueEQ. 
+func Value(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldValue, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. 
+func UpdatedAtNEQ(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// KeyEQ applies the EQ predicate on the "key" field. +func KeyEQ(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldKey, v)) +} + +// KeyNEQ applies the NEQ predicate on the "key" field. +func KeyNEQ(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNEQ(FieldKey, v)) +} + +// KeyIn applies the In predicate on the "key" field. +func KeyIn(vs ...string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldIn(FieldKey, vs...)) +} + +// KeyNotIn applies the NotIn predicate on the "key" field. 
+func KeyNotIn(vs ...string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNotIn(FieldKey, vs...)) +} + +// KeyGT applies the GT predicate on the "key" field. +func KeyGT(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGT(FieldKey, v)) +} + +// KeyGTE applies the GTE predicate on the "key" field. +func KeyGTE(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGTE(FieldKey, v)) +} + +// KeyLT applies the LT predicate on the "key" field. +func KeyLT(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLT(FieldKey, v)) +} + +// KeyLTE applies the LTE predicate on the "key" field. +func KeyLTE(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLTE(FieldKey, v)) +} + +// KeyContains applies the Contains predicate on the "key" field. +func KeyContains(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldContains(FieldKey, v)) +} + +// KeyHasPrefix applies the HasPrefix predicate on the "key" field. +func KeyHasPrefix(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldHasPrefix(FieldKey, v)) +} + +// KeyHasSuffix applies the HasSuffix predicate on the "key" field. +func KeyHasSuffix(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldHasSuffix(FieldKey, v)) +} + +// KeyEqualFold applies the EqualFold predicate on the "key" field. +func KeyEqualFold(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEqualFold(FieldKey, v)) +} + +// KeyContainsFold applies the ContainsFold predicate on the "key" field. +func KeyContainsFold(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldContainsFold(FieldKey, v)) +} + +// ValueEQ applies the EQ predicate on the "value" field. 
+func ValueEQ(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldValue, v)) +} + +// ValueNEQ applies the NEQ predicate on the "value" field. +func ValueNEQ(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNEQ(FieldValue, v)) +} + +// ValueIn applies the In predicate on the "value" field. +func ValueIn(vs ...string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldIn(FieldValue, vs...)) +} + +// ValueNotIn applies the NotIn predicate on the "value" field. +func ValueNotIn(vs ...string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNotIn(FieldValue, vs...)) +} + +// ValueGT applies the GT predicate on the "value" field. +func ValueGT(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGT(FieldValue, v)) +} + +// ValueGTE applies the GTE predicate on the "value" field. +func ValueGTE(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGTE(FieldValue, v)) +} + +// ValueLT applies the LT predicate on the "value" field. +func ValueLT(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLT(FieldValue, v)) +} + +// ValueLTE applies the LTE predicate on the "value" field. +func ValueLTE(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLTE(FieldValue, v)) +} + +// ValueContains applies the Contains predicate on the "value" field. +func ValueContains(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldContains(FieldValue, v)) +} + +// ValueHasPrefix applies the HasPrefix predicate on the "value" field. +func ValueHasPrefix(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldHasPrefix(FieldValue, v)) +} + +// ValueHasSuffix applies the HasSuffix predicate on the "value" field. 
+func ValueHasSuffix(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldHasSuffix(FieldValue, v)) +} + +// ValueEqualFold applies the EqualFold predicate on the "value" field. +func ValueEqualFold(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEqualFold(FieldValue, v)) +} + +// ValueContainsFold applies the ContainsFold predicate on the "value" field. +func ValueContainsFold(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldContainsFold(FieldValue, v)) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.SecuritySecret) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.SecuritySecret) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.SecuritySecret) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.NotPredicates(p)) +} diff --git a/backend/ent/securitysecret_create.go b/backend/ent/securitysecret_create.go new file mode 100644 index 00000000..397503be --- /dev/null +++ b/backend/ent/securitysecret_create.go @@ -0,0 +1,626 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" +) + +// SecuritySecretCreate is the builder for creating a SecuritySecret entity. +type SecuritySecretCreate struct { + config + mutation *SecuritySecretMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. 
+func (_c *SecuritySecretCreate) SetCreatedAt(v time.Time) *SecuritySecretCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *SecuritySecretCreate) SetNillableCreatedAt(v *time.Time) *SecuritySecretCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *SecuritySecretCreate) SetUpdatedAt(v time.Time) *SecuritySecretCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *SecuritySecretCreate) SetNillableUpdatedAt(v *time.Time) *SecuritySecretCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetKey sets the "key" field. +func (_c *SecuritySecretCreate) SetKey(v string) *SecuritySecretCreate { + _c.mutation.SetKey(v) + return _c +} + +// SetValue sets the "value" field. +func (_c *SecuritySecretCreate) SetValue(v string) *SecuritySecretCreate { + _c.mutation.SetValue(v) + return _c +} + +// Mutation returns the SecuritySecretMutation object of the builder. +func (_c *SecuritySecretCreate) Mutation() *SecuritySecretMutation { + return _c.mutation +} + +// Save creates the SecuritySecret in the database. +func (_c *SecuritySecretCreate) Save(ctx context.Context) (*SecuritySecret, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *SecuritySecretCreate) SaveX(ctx context.Context) *SecuritySecret { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *SecuritySecretCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (_c *SecuritySecretCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *SecuritySecretCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := securitysecret.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := securitysecret.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *SecuritySecretCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "SecuritySecret.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "SecuritySecret.updated_at"`)} + } + if _, ok := _c.mutation.Key(); !ok { + return &ValidationError{Name: "key", err: errors.New(`ent: missing required field "SecuritySecret.key"`)} + } + if v, ok := _c.mutation.Key(); ok { + if err := securitysecret.KeyValidator(v); err != nil { + return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.key": %w`, err)} + } + } + if _, ok := _c.mutation.Value(); !ok { + return &ValidationError{Name: "value", err: errors.New(`ent: missing required field "SecuritySecret.value"`)} + } + if v, ok := _c.mutation.Value(); ok { + if err := securitysecret.ValueValidator(v); err != nil { + return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.value": %w`, err)} + } + } + return nil +} + +func (_c *SecuritySecretCreate) sqlSave(ctx context.Context) (*SecuritySecret, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if 
sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *SecuritySecretCreate) createSpec() (*SecuritySecret, *sqlgraph.CreateSpec) { + var ( + _node = &SecuritySecret{config: _c.config} + _spec = sqlgraph.NewCreateSpec(securitysecret.Table, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(securitysecret.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(securitysecret.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.Key(); ok { + _spec.SetField(securitysecret.FieldKey, field.TypeString, value) + _node.Key = value + } + if value, ok := _c.mutation.Value(); ok { + _spec.SetField(securitysecret.FieldValue, field.TypeString, value) + _node.Value = value + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.SecuritySecret.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.SecuritySecretUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *SecuritySecretCreate) OnConflict(opts ...sql.ConflictOption) *SecuritySecretUpsertOne { + _c.conflict = opts + return &SecuritySecretUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.SecuritySecret.Create(). 
+// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *SecuritySecretCreate) OnConflictColumns(columns ...string) *SecuritySecretUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &SecuritySecretUpsertOne{ + create: _c, + } +} + +type ( + // SecuritySecretUpsertOne is the builder for "upsert"-ing + // one SecuritySecret node. + SecuritySecretUpsertOne struct { + create *SecuritySecretCreate + } + + // SecuritySecretUpsert is the "OnConflict" setter. + SecuritySecretUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *SecuritySecretUpsert) SetUpdatedAt(v time.Time) *SecuritySecretUpsert { + u.Set(securitysecret.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *SecuritySecretUpsert) UpdateUpdatedAt() *SecuritySecretUpsert { + u.SetExcluded(securitysecret.FieldUpdatedAt) + return u +} + +// SetKey sets the "key" field. +func (u *SecuritySecretUpsert) SetKey(v string) *SecuritySecretUpsert { + u.Set(securitysecret.FieldKey, v) + return u +} + +// UpdateKey sets the "key" field to the value that was provided on create. +func (u *SecuritySecretUpsert) UpdateKey() *SecuritySecretUpsert { + u.SetExcluded(securitysecret.FieldKey) + return u +} + +// SetValue sets the "value" field. +func (u *SecuritySecretUpsert) SetValue(v string) *SecuritySecretUpsert { + u.Set(securitysecret.FieldValue, v) + return u +} + +// UpdateValue sets the "value" field to the value that was provided on create. +func (u *SecuritySecretUpsert) UpdateValue() *SecuritySecretUpsert { + u.SetExcluded(securitysecret.FieldValue) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.SecuritySecret.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). 
+// Exec(ctx) +func (u *SecuritySecretUpsertOne) UpdateNewValues() *SecuritySecretUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(securitysecret.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.SecuritySecret.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *SecuritySecretUpsertOne) Ignore() *SecuritySecretUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *SecuritySecretUpsertOne) DoNothing() *SecuritySecretUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the SecuritySecretCreate.OnConflict +// documentation for more info. +func (u *SecuritySecretUpsertOne) Update(set func(*SecuritySecretUpsert)) *SecuritySecretUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&SecuritySecretUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *SecuritySecretUpsertOne) SetUpdatedAt(v time.Time) *SecuritySecretUpsertOne { + return u.Update(func(s *SecuritySecretUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *SecuritySecretUpsertOne) UpdateUpdatedAt() *SecuritySecretUpsertOne { + return u.Update(func(s *SecuritySecretUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetKey sets the "key" field. 
+func (u *SecuritySecretUpsertOne) SetKey(v string) *SecuritySecretUpsertOne { + return u.Update(func(s *SecuritySecretUpsert) { + s.SetKey(v) + }) +} + +// UpdateKey sets the "key" field to the value that was provided on create. +func (u *SecuritySecretUpsertOne) UpdateKey() *SecuritySecretUpsertOne { + return u.Update(func(s *SecuritySecretUpsert) { + s.UpdateKey() + }) +} + +// SetValue sets the "value" field. +func (u *SecuritySecretUpsertOne) SetValue(v string) *SecuritySecretUpsertOne { + return u.Update(func(s *SecuritySecretUpsert) { + s.SetValue(v) + }) +} + +// UpdateValue sets the "value" field to the value that was provided on create. +func (u *SecuritySecretUpsertOne) UpdateValue() *SecuritySecretUpsertOne { + return u.Update(func(s *SecuritySecretUpsert) { + s.UpdateValue() + }) +} + +// Exec executes the query. +func (u *SecuritySecretUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for SecuritySecretCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *SecuritySecretUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *SecuritySecretUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *SecuritySecretUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// SecuritySecretCreateBulk is the builder for creating many SecuritySecret entities in bulk. +type SecuritySecretCreateBulk struct { + config + err error + builders []*SecuritySecretCreate + conflict []sql.ConflictOption +} + +// Save creates the SecuritySecret entities in the database. 
+func (_c *SecuritySecretCreateBulk) Save(ctx context.Context) ([]*SecuritySecret, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*SecuritySecret, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*SecuritySecretMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *SecuritySecretCreateBulk) SaveX(ctx context.Context) []*SecuritySecret { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. 
+func (_c *SecuritySecretCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *SecuritySecretCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.SecuritySecret.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.SecuritySecretUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *SecuritySecretCreateBulk) OnConflict(opts ...sql.ConflictOption) *SecuritySecretUpsertBulk { + _c.conflict = opts + return &SecuritySecretUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.SecuritySecret.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *SecuritySecretCreateBulk) OnConflictColumns(columns ...string) *SecuritySecretUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &SecuritySecretUpsertBulk{ + create: _c, + } +} + +// SecuritySecretUpsertBulk is the builder for "upsert"-ing +// a bulk of SecuritySecret nodes. +type SecuritySecretUpsertBulk struct { + create *SecuritySecretCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.SecuritySecret.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). 
+// Exec(ctx) +func (u *SecuritySecretUpsertBulk) UpdateNewValues() *SecuritySecretUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(securitysecret.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.SecuritySecret.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *SecuritySecretUpsertBulk) Ignore() *SecuritySecretUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *SecuritySecretUpsertBulk) DoNothing() *SecuritySecretUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the SecuritySecretCreateBulk.OnConflict +// documentation for more info. +func (u *SecuritySecretUpsertBulk) Update(set func(*SecuritySecretUpsert)) *SecuritySecretUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&SecuritySecretUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *SecuritySecretUpsertBulk) SetUpdatedAt(v time.Time) *SecuritySecretUpsertBulk { + return u.Update(func(s *SecuritySecretUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *SecuritySecretUpsertBulk) UpdateUpdatedAt() *SecuritySecretUpsertBulk { + return u.Update(func(s *SecuritySecretUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetKey sets the "key" field. 
+func (u *SecuritySecretUpsertBulk) SetKey(v string) *SecuritySecretUpsertBulk { + return u.Update(func(s *SecuritySecretUpsert) { + s.SetKey(v) + }) +} + +// UpdateKey sets the "key" field to the value that was provided on create. +func (u *SecuritySecretUpsertBulk) UpdateKey() *SecuritySecretUpsertBulk { + return u.Update(func(s *SecuritySecretUpsert) { + s.UpdateKey() + }) +} + +// SetValue sets the "value" field. +func (u *SecuritySecretUpsertBulk) SetValue(v string) *SecuritySecretUpsertBulk { + return u.Update(func(s *SecuritySecretUpsert) { + s.SetValue(v) + }) +} + +// UpdateValue sets the "value" field to the value that was provided on create. +func (u *SecuritySecretUpsertBulk) UpdateValue() *SecuritySecretUpsertBulk { + return u.Update(func(s *SecuritySecretUpsert) { + s.UpdateValue() + }) +} + +// Exec executes the query. +func (u *SecuritySecretUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the SecuritySecretCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for SecuritySecretCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *SecuritySecretUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/securitysecret_delete.go b/backend/ent/securitysecret_delete.go new file mode 100644 index 00000000..66757138 --- /dev/null +++ b/backend/ent/securitysecret_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" +) + +// SecuritySecretDelete is the builder for deleting a SecuritySecret entity. +type SecuritySecretDelete struct { + config + hooks []Hook + mutation *SecuritySecretMutation +} + +// Where appends a list predicates to the SecuritySecretDelete builder. +func (_d *SecuritySecretDelete) Where(ps ...predicate.SecuritySecret) *SecuritySecretDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *SecuritySecretDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *SecuritySecretDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *SecuritySecretDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(securitysecret.Table, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// SecuritySecretDeleteOne is the builder for deleting a single SecuritySecret entity. +type SecuritySecretDeleteOne struct { + _d *SecuritySecretDelete +} + +// Where appends a list predicates to the SecuritySecretDelete builder. +func (_d *SecuritySecretDeleteOne) Where(ps ...predicate.SecuritySecret) *SecuritySecretDeleteOne { + _d._d.mutation.Where(ps...) 
+ return _d +} + +// Exec executes the deletion query. +func (_d *SecuritySecretDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{securitysecret.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *SecuritySecretDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/securitysecret_query.go b/backend/ent/securitysecret_query.go new file mode 100644 index 00000000..fe53adf1 --- /dev/null +++ b/backend/ent/securitysecret_query.go @@ -0,0 +1,564 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" +) + +// SecuritySecretQuery is the builder for querying SecuritySecret entities. +type SecuritySecretQuery struct { + config + ctx *QueryContext + order []securitysecret.OrderOption + inters []Interceptor + predicates []predicate.SecuritySecret + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the SecuritySecretQuery builder. +func (_q *SecuritySecretQuery) Where(ps ...predicate.SecuritySecret) *SecuritySecretQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *SecuritySecretQuery) Limit(limit int) *SecuritySecretQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. 
+func (_q *SecuritySecretQuery) Offset(offset int) *SecuritySecretQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *SecuritySecretQuery) Unique(unique bool) *SecuritySecretQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *SecuritySecretQuery) Order(o ...securitysecret.OrderOption) *SecuritySecretQuery { + _q.order = append(_q.order, o...) + return _q +} + +// First returns the first SecuritySecret entity from the query. +// Returns a *NotFoundError when no SecuritySecret was found. +func (_q *SecuritySecretQuery) First(ctx context.Context) (*SecuritySecret, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{securitysecret.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *SecuritySecretQuery) FirstX(ctx context.Context) *SecuritySecret { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first SecuritySecret ID from the query. +// Returns a *NotFoundError when no SecuritySecret ID was found. +func (_q *SecuritySecretQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{securitysecret.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. 
+func (_q *SecuritySecretQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single SecuritySecret entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one SecuritySecret entity is found. +// Returns a *NotFoundError when no SecuritySecret entities are found. +func (_q *SecuritySecretQuery) Only(ctx context.Context) (*SecuritySecret, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{securitysecret.Label} + default: + return nil, &NotSingularError{securitysecret.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *SecuritySecretQuery) OnlyX(ctx context.Context) *SecuritySecret { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only SecuritySecret ID in the query. +// Returns a *NotSingularError when more than one SecuritySecret ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *SecuritySecretQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{securitysecret.Label} + default: + err = &NotSingularError{securitysecret.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *SecuritySecretQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of SecuritySecrets. 
+func (_q *SecuritySecretQuery) All(ctx context.Context) ([]*SecuritySecret, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*SecuritySecret, *SecuritySecretQuery]() + return withInterceptors[[]*SecuritySecret](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *SecuritySecretQuery) AllX(ctx context.Context) []*SecuritySecret { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of SecuritySecret IDs. +func (_q *SecuritySecretQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(securitysecret.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *SecuritySecretQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *SecuritySecretQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*SecuritySecretQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *SecuritySecretQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. 
+func (_q *SecuritySecretQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *SecuritySecretQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the SecuritySecretQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *SecuritySecretQuery) Clone() *SecuritySecretQuery { + if _q == nil { + return nil + } + return &SecuritySecretQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]securitysecret.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.SecuritySecret{}, _q.predicates...), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.SecuritySecret.Query(). +// GroupBy(securitysecret.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *SecuritySecretQuery) GroupBy(field string, fields ...string) *SecuritySecretGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) 
+ grbuild := &SecuritySecretGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = securitysecret.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.SecuritySecret.Query(). +// Select(securitysecret.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *SecuritySecretQuery) Select(fields ...string) *SecuritySecretSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &SecuritySecretSelect{SecuritySecretQuery: _q} + sbuild.label = securitysecret.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a SecuritySecretSelect configured with the given aggregations. +func (_q *SecuritySecretQuery) Aggregate(fns ...AggregateFunc) *SecuritySecretSelect { + return _q.Select().Aggregate(fns...) 
+} + +func (_q *SecuritySecretQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !securitysecret.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *SecuritySecretQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*SecuritySecret, error) { + var ( + nodes = []*SecuritySecret{} + _spec = _q.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*SecuritySecret).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &SecuritySecret{config: _q.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + return nodes, nil +} + +func (_q *SecuritySecretQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *SecuritySecretQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(securitysecret.Table, securitysecret.Columns, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := 
_q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, securitysecret.FieldID) + for i := range fields { + if fields[i] != securitysecret.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *SecuritySecretQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(securitysecret.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = securitysecret.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... 
for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *SecuritySecretQuery) ForUpdate(opts ...sql.LockOption) *SecuritySecretQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *SecuritySecretQuery) ForShare(opts ...sql.LockOption) *SecuritySecretQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// SecuritySecretGroupBy is the group-by builder for SecuritySecret entities. +type SecuritySecretGroupBy struct { + selector + build *SecuritySecretQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *SecuritySecretGroupBy) Aggregate(fns ...AggregateFunc) *SecuritySecretGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. 
+func (_g *SecuritySecretGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*SecuritySecretQuery, *SecuritySecretGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *SecuritySecretGroupBy) sqlScan(ctx context.Context, root *SecuritySecretQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// SecuritySecretSelect is the builder for selecting fields of SecuritySecret entities. +type SecuritySecretSelect struct { + *SecuritySecretQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *SecuritySecretSelect) Aggregate(fns ...AggregateFunc) *SecuritySecretSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. 
+func (_s *SecuritySecretSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*SecuritySecretQuery, *SecuritySecretSelect](ctx, _s.SecuritySecretQuery, _s, _s.inters, v) +} + +func (_s *SecuritySecretSelect) sqlScan(ctx context.Context, root *SecuritySecretQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/securitysecret_update.go b/backend/ent/securitysecret_update.go new file mode 100644 index 00000000..ec3979af --- /dev/null +++ b/backend/ent/securitysecret_update.go @@ -0,0 +1,316 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" +) + +// SecuritySecretUpdate is the builder for updating SecuritySecret entities. +type SecuritySecretUpdate struct { + config + hooks []Hook + mutation *SecuritySecretMutation +} + +// Where appends a list predicates to the SecuritySecretUpdate builder. +func (_u *SecuritySecretUpdate) Where(ps ...predicate.SecuritySecret) *SecuritySecretUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. 
+func (_u *SecuritySecretUpdate) SetUpdatedAt(v time.Time) *SecuritySecretUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetKey sets the "key" field. +func (_u *SecuritySecretUpdate) SetKey(v string) *SecuritySecretUpdate { + _u.mutation.SetKey(v) + return _u +} + +// SetNillableKey sets the "key" field if the given value is not nil. +func (_u *SecuritySecretUpdate) SetNillableKey(v *string) *SecuritySecretUpdate { + if v != nil { + _u.SetKey(*v) + } + return _u +} + +// SetValue sets the "value" field. +func (_u *SecuritySecretUpdate) SetValue(v string) *SecuritySecretUpdate { + _u.mutation.SetValue(v) + return _u +} + +// SetNillableValue sets the "value" field if the given value is not nil. +func (_u *SecuritySecretUpdate) SetNillableValue(v *string) *SecuritySecretUpdate { + if v != nil { + _u.SetValue(*v) + } + return _u +} + +// Mutation returns the SecuritySecretMutation object of the builder. +func (_u *SecuritySecretUpdate) Mutation() *SecuritySecretMutation { + return _u.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *SecuritySecretUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *SecuritySecretUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *SecuritySecretUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *SecuritySecretUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. 
+func (_u *SecuritySecretUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := securitysecret.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *SecuritySecretUpdate) check() error { + if v, ok := _u.mutation.Key(); ok { + if err := securitysecret.KeyValidator(v); err != nil { + return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.key": %w`, err)} + } + } + if v, ok := _u.mutation.Value(); ok { + if err := securitysecret.ValueValidator(v); err != nil { + return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.value": %w`, err)} + } + } + return nil +} + +func (_u *SecuritySecretUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(securitysecret.Table, securitysecret.Columns, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(securitysecret.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Key(); ok { + _spec.SetField(securitysecret.FieldKey, field.TypeString, value) + } + if value, ok := _u.mutation.Value(); ok { + _spec.SetField(securitysecret.FieldValue, field.TypeString, value) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{securitysecret.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// SecuritySecretUpdateOne is the builder for updating a single SecuritySecret 
entity. +type SecuritySecretUpdateOne struct { + config + fields []string + hooks []Hook + mutation *SecuritySecretMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *SecuritySecretUpdateOne) SetUpdatedAt(v time.Time) *SecuritySecretUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetKey sets the "key" field. +func (_u *SecuritySecretUpdateOne) SetKey(v string) *SecuritySecretUpdateOne { + _u.mutation.SetKey(v) + return _u +} + +// SetNillableKey sets the "key" field if the given value is not nil. +func (_u *SecuritySecretUpdateOne) SetNillableKey(v *string) *SecuritySecretUpdateOne { + if v != nil { + _u.SetKey(*v) + } + return _u +} + +// SetValue sets the "value" field. +func (_u *SecuritySecretUpdateOne) SetValue(v string) *SecuritySecretUpdateOne { + _u.mutation.SetValue(v) + return _u +} + +// SetNillableValue sets the "value" field if the given value is not nil. +func (_u *SecuritySecretUpdateOne) SetNillableValue(v *string) *SecuritySecretUpdateOne { + if v != nil { + _u.SetValue(*v) + } + return _u +} + +// Mutation returns the SecuritySecretMutation object of the builder. +func (_u *SecuritySecretUpdateOne) Mutation() *SecuritySecretMutation { + return _u.mutation +} + +// Where appends a list predicates to the SecuritySecretUpdate builder. +func (_u *SecuritySecretUpdateOne) Where(ps ...predicate.SecuritySecret) *SecuritySecretUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *SecuritySecretUpdateOne) Select(field string, fields ...string) *SecuritySecretUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated SecuritySecret entity. 
+func (_u *SecuritySecretUpdateOne) Save(ctx context.Context) (*SecuritySecret, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *SecuritySecretUpdateOne) SaveX(ctx context.Context) *SecuritySecret { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *SecuritySecretUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *SecuritySecretUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *SecuritySecretUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := securitysecret.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_u *SecuritySecretUpdateOne) check() error { + if v, ok := _u.mutation.Key(); ok { + if err := securitysecret.KeyValidator(v); err != nil { + return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.key": %w`, err)} + } + } + if v, ok := _u.mutation.Value(); ok { + if err := securitysecret.ValueValidator(v); err != nil { + return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.value": %w`, err)} + } + } + return nil +} + +func (_u *SecuritySecretUpdateOne) sqlSave(ctx context.Context) (_node *SecuritySecret, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(securitysecret.Table, securitysecret.Columns, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "SecuritySecret.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, securitysecret.FieldID) + for _, f := range fields { + if !securitysecret.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != securitysecret.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(securitysecret.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Key(); ok { + _spec.SetField(securitysecret.FieldKey, field.TypeString, value) + } + if value, ok := _u.mutation.Value(); ok { + _spec.SetField(securitysecret.FieldValue, field.TypeString, value) + } + _node = &SecuritySecret{config: _u.config} 
+ _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{securitysecret.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/tx.go b/backend/ent/tx.go index 45d83428..cd3b2296 100644 --- a/backend/ent/tx.go +++ b/backend/ent/tx.go @@ -28,6 +28,8 @@ type Tx struct { ErrorPassthroughRule *ErrorPassthroughRuleClient // Group is the client for interacting with the Group builders. Group *GroupClient + // IdempotencyRecord is the client for interacting with the IdempotencyRecord builders. + IdempotencyRecord *IdempotencyRecordClient // PromoCode is the client for interacting with the PromoCode builders. PromoCode *PromoCodeClient // PromoCodeUsage is the client for interacting with the PromoCodeUsage builders. @@ -36,6 +38,8 @@ type Tx struct { Proxy *ProxyClient // RedeemCode is the client for interacting with the RedeemCode builders. RedeemCode *RedeemCodeClient + // SecuritySecret is the client for interacting with the SecuritySecret builders. + SecuritySecret *SecuritySecretClient // Setting is the client for interacting with the Setting builders. Setting *SettingClient // UsageCleanupTask is the client for interacting with the UsageCleanupTask builders. 
@@ -190,10 +194,12 @@ func (tx *Tx) init() { tx.AnnouncementRead = NewAnnouncementReadClient(tx.config) tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config) tx.Group = NewGroupClient(tx.config) + tx.IdempotencyRecord = NewIdempotencyRecordClient(tx.config) tx.PromoCode = NewPromoCodeClient(tx.config) tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config) tx.Proxy = NewProxyClient(tx.config) tx.RedeemCode = NewRedeemCodeClient(tx.config) + tx.SecuritySecret = NewSecuritySecretClient(tx.config) tx.Setting = NewSettingClient(tx.config) tx.UsageCleanupTask = NewUsageCleanupTaskClient(tx.config) tx.UsageLog = NewUsageLogClient(tx.config) diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go index 81c466b4..f6968d0d 100644 --- a/backend/ent/usagelog.go +++ b/backend/ent/usagelog.go @@ -80,6 +80,10 @@ type UsageLog struct { ImageCount int `json:"image_count,omitempty"` // ImageSize holds the value of the "image_size" field. ImageSize *string `json:"image_size,omitempty"` + // MediaType holds the value of the "media_type" field. + MediaType *string `json:"media_type,omitempty"` + // CacheTTLOverridden holds the value of the "cache_ttl_overridden" field. + CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"` // CreatedAt holds the value of the "created_at" field. CreatedAt time.Time `json:"created_at,omitempty"` // Edges holds the relations/edges for other nodes in the graph. 
@@ -165,13 +169,13 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case usagelog.FieldStream: + case usagelog.FieldStream, usagelog.FieldCacheTTLOverridden: values[i] = new(sql.NullBool) case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier: values[i] = new(sql.NullFloat64) case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: values[i] = new(sql.NullInt64) - case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize: + case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType: values[i] = new(sql.NullString) case usagelog.FieldCreatedAt: values[i] = new(sql.NullTime) @@ -378,6 +382,19 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error { _m.ImageSize = new(string) *_m.ImageSize = value.String } + case usagelog.FieldMediaType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field media_type", values[i]) + } else if value.Valid { + _m.MediaType = new(string) + *_m.MediaType = value.String + } + case usagelog.FieldCacheTTLOverridden: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i]) + } else if 
value.Valid { + _m.CacheTTLOverridden = value.Bool + } case usagelog.FieldCreatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) @@ -548,6 +565,14 @@ func (_m *UsageLog) String() string { builder.WriteString(*v) } builder.WriteString(", ") + if v := _m.MediaType; v != nil { + builder.WriteString("media_type=") + builder.WriteString(*v) + } + builder.WriteString(", ") + builder.WriteString("cache_ttl_overridden=") + builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden)) + builder.WriteString(", ") builder.WriteString("created_at=") builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) builder.WriteByte(')') diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go index 980f1e58..ba97b843 100644 --- a/backend/ent/usagelog/usagelog.go +++ b/backend/ent/usagelog/usagelog.go @@ -72,6 +72,10 @@ const ( FieldImageCount = "image_count" // FieldImageSize holds the string denoting the image_size field in the database. FieldImageSize = "image_size" + // FieldMediaType holds the string denoting the media_type field in the database. + FieldMediaType = "media_type" + // FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database. + FieldCacheTTLOverridden = "cache_ttl_overridden" // FieldCreatedAt holds the string denoting the created_at field in the database. FieldCreatedAt = "created_at" // EdgeUser holds the string denoting the user edge name in mutations. @@ -155,6 +159,8 @@ var Columns = []string{ FieldIPAddress, FieldImageCount, FieldImageSize, + FieldMediaType, + FieldCacheTTLOverridden, FieldCreatedAt, } @@ -211,6 +217,10 @@ var ( DefaultImageCount int // ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. ImageSizeValidator func(string) error + // MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. 
+ MediaTypeValidator func(string) error + // DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field. + DefaultCacheTTLOverridden bool // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time ) @@ -368,6 +378,16 @@ func ByImageSize(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldImageSize, opts...).ToFunc() } +// ByMediaType orders the results by the media_type field. +func ByMediaType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMediaType, opts...).ToFunc() +} + +// ByCacheTTLOverridden orders the results by the cache_ttl_overridden field. +func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc() +} + // ByCreatedAt orders the results by the created_at field. func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go index 28e2ab4c..af960335 100644 --- a/backend/ent/usagelog/where.go +++ b/backend/ent/usagelog/where.go @@ -200,6 +200,16 @@ func ImageSize(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v)) } +// MediaType applies equality check predicate on the "media_type" field. It's identical to MediaTypeEQ. +func MediaType(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v)) +} + +// CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ. +func CacheTTLOverridden(v bool) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v)) +} + // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. 
func CreatedAt(v time.Time) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v)) @@ -1440,6 +1450,91 @@ func ImageSizeContainsFold(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v)) } +// MediaTypeEQ applies the EQ predicate on the "media_type" field. +func MediaTypeEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v)) +} + +// MediaTypeNEQ applies the NEQ predicate on the "media_type" field. +func MediaTypeNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldMediaType, v)) +} + +// MediaTypeIn applies the In predicate on the "media_type" field. +func MediaTypeIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldMediaType, vs...)) +} + +// MediaTypeNotIn applies the NotIn predicate on the "media_type" field. +func MediaTypeNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldMediaType, vs...)) +} + +// MediaTypeGT applies the GT predicate on the "media_type" field. +func MediaTypeGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldMediaType, v)) +} + +// MediaTypeGTE applies the GTE predicate on the "media_type" field. +func MediaTypeGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldMediaType, v)) +} + +// MediaTypeLT applies the LT predicate on the "media_type" field. +func MediaTypeLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldMediaType, v)) +} + +// MediaTypeLTE applies the LTE predicate on the "media_type" field. +func MediaTypeLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldMediaType, v)) +} + +// MediaTypeContains applies the Contains predicate on the "media_type" field. 
+func MediaTypeContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldMediaType, v)) +} + +// MediaTypeHasPrefix applies the HasPrefix predicate on the "media_type" field. +func MediaTypeHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldMediaType, v)) +} + +// MediaTypeHasSuffix applies the HasSuffix predicate on the "media_type" field. +func MediaTypeHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldMediaType, v)) +} + +// MediaTypeIsNil applies the IsNil predicate on the "media_type" field. +func MediaTypeIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldMediaType)) +} + +// MediaTypeNotNil applies the NotNil predicate on the "media_type" field. +func MediaTypeNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldMediaType)) +} + +// MediaTypeEqualFold applies the EqualFold predicate on the "media_type" field. +func MediaTypeEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldMediaType, v)) +} + +// MediaTypeContainsFold applies the ContainsFold predicate on the "media_type" field. +func MediaTypeContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldMediaType, v)) +} + +// CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field. +func CacheTTLOverriddenEQ(v bool) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v)) +} + +// CacheTTLOverriddenNEQ applies the NEQ predicate on the "cache_ttl_overridden" field. +func CacheTTLOverriddenNEQ(v bool) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldCacheTTLOverridden, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. 
func CreatedAtEQ(v time.Time) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldCreatedAt, v)) diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go index a17d6507..e0285a5e 100644 --- a/backend/ent/usagelog_create.go +++ b/backend/ent/usagelog_create.go @@ -393,6 +393,34 @@ func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate { return _c } +// SetMediaType sets the "media_type" field. +func (_c *UsageLogCreate) SetMediaType(v string) *UsageLogCreate { + _c.mutation.SetMediaType(v) + return _c +} + +// SetNillableMediaType sets the "media_type" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableMediaType(v *string) *UsageLogCreate { + if v != nil { + _c.SetMediaType(*v) + } + return _c +} + +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate { + _c.mutation.SetCacheTTLOverridden(v) + return _c +} + +// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableCacheTTLOverridden(v *bool) *UsageLogCreate { + if v != nil { + _c.SetCacheTTLOverridden(*v) + } + return _c +} + // SetCreatedAt sets the "created_at" field. 
func (_c *UsageLogCreate) SetCreatedAt(v time.Time) *UsageLogCreate { _c.mutation.SetCreatedAt(v) @@ -531,6 +559,10 @@ func (_c *UsageLogCreate) defaults() { v := usagelog.DefaultImageCount _c.mutation.SetImageCount(v) } + if _, ok := _c.mutation.CacheTTLOverridden(); !ok { + v := usagelog.DefaultCacheTTLOverridden + _c.mutation.SetCacheTTLOverridden(v) + } if _, ok := _c.mutation.CreatedAt(); !ok { v := usagelog.DefaultCreatedAt() _c.mutation.SetCreatedAt(v) @@ -627,6 +659,14 @@ func (_c *UsageLogCreate) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } + if v, ok := _c.mutation.MediaType(); ok { + if err := usagelog.MediaTypeValidator(v); err != nil { + return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)} + } + } + if _, ok := _c.mutation.CacheTTLOverridden(); !ok { + return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)} + } if _, ok := _c.mutation.CreatedAt(); !ok { return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UsageLog.created_at"`)} } @@ -762,6 +802,14 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { _spec.SetField(usagelog.FieldImageSize, field.TypeString, value) _node.ImageSize = &value } + if value, ok := _c.mutation.MediaType(); ok { + _spec.SetField(usagelog.FieldMediaType, field.TypeString, value) + _node.MediaType = &value + } + if value, ok := _c.mutation.CacheTTLOverridden(); ok { + _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) + _node.CacheTTLOverridden = value + } if value, ok := _c.mutation.CreatedAt(); ok { _spec.SetField(usagelog.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = value @@ -1407,6 +1455,36 @@ func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert { return u } +// SetMediaType 
sets the "media_type" field. +func (u *UsageLogUpsert) SetMediaType(v string) *UsageLogUpsert { + u.Set(usagelog.FieldMediaType, v) + return u +} + +// UpdateMediaType sets the "media_type" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateMediaType() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldMediaType) + return u +} + +// ClearMediaType clears the value of the "media_type" field. +func (u *UsageLogUpsert) ClearMediaType() *UsageLogUpsert { + u.SetNull(usagelog.FieldMediaType) + return u +} + +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert { + u.Set(usagelog.FieldCacheTTLOverridden, v) + return u +} + +// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateCacheTTLOverridden() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldCacheTTLOverridden) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -2040,6 +2118,41 @@ func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne { }) } +// SetMediaType sets the "media_type" field. +func (u *UsageLogUpsertOne) SetMediaType(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetMediaType(v) + }) +} + +// UpdateMediaType sets the "media_type" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateMediaType() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateMediaType() + }) +} + +// ClearMediaType clears the value of the "media_type" field. +func (u *UsageLogUpsertOne) ClearMediaType() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearMediaType() + }) +} + +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. 
+func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheTTLOverridden(v) + }) +} + +// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateCacheTTLOverridden() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheTTLOverridden() + }) +} + // Exec executes the query. func (u *UsageLogUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -2839,6 +2952,41 @@ func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk { }) } +// SetMediaType sets the "media_type" field. +func (u *UsageLogUpsertBulk) SetMediaType(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetMediaType(v) + }) +} + +// UpdateMediaType sets the "media_type" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateMediaType() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateMediaType() + }) +} + +// ClearMediaType clears the value of the "media_type" field. +func (u *UsageLogUpsertBulk) ClearMediaType() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearMediaType() + }) +} + +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetCacheTTLOverridden(v) + }) +} + +// UpdateCacheTTLOverridden sets the "cache_ttl_overridden" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateCacheTTLOverridden() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateCacheTTLOverridden() + }) +} + // Exec executes the query. 
func (u *UsageLogUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go index 571a7b3c..b46e5b56 100644 --- a/backend/ent/usagelog_update.go +++ b/backend/ent/usagelog_update.go @@ -612,6 +612,40 @@ func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate { return _u } +// SetMediaType sets the "media_type" field. +func (_u *UsageLogUpdate) SetMediaType(v string) *UsageLogUpdate { + _u.mutation.SetMediaType(v) + return _u +} + +// SetNillableMediaType sets the "media_type" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableMediaType(v *string) *UsageLogUpdate { + if v != nil { + _u.SetMediaType(*v) + } + return _u +} + +// ClearMediaType clears the value of the "media_type" field. +func (_u *UsageLogUpdate) ClearMediaType() *UsageLogUpdate { + _u.mutation.ClearMediaType() + return _u +} + +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate { + _u.mutation.SetCacheTTLOverridden(v) + return _u +} + +// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableCacheTTLOverridden(v *bool) *UsageLogUpdate { + if v != nil { + _u.SetCacheTTLOverridden(*v) + } + return _u +} + // SetUser sets the "user" edge to the User entity. 
func (_u *UsageLogUpdate) SetUser(v *User) *UsageLogUpdate { return _u.SetUserID(v.ID) @@ -726,6 +760,11 @@ func (_u *UsageLogUpdate) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } + if v, ok := _u.mutation.MediaType(); ok { + if err := usagelog.MediaTypeValidator(v); err != nil { + return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)} + } + } if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { return errors.New(`ent: clearing a required unique edge "UsageLog.user"`) } @@ -894,6 +933,15 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.ImageSizeCleared() { _spec.ClearField(usagelog.FieldImageSize, field.TypeString) } + if value, ok := _u.mutation.MediaType(); ok { + _spec.SetField(usagelog.FieldMediaType, field.TypeString, value) + } + if _u.mutation.MediaTypeCleared() { + _spec.ClearField(usagelog.FieldMediaType, field.TypeString) + } + if value, ok := _u.mutation.CacheTTLOverridden(); ok { + _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -1639,6 +1687,40 @@ func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne { return _u } +// SetMediaType sets the "media_type" field. +func (_u *UsageLogUpdateOne) SetMediaType(v string) *UsageLogUpdateOne { + _u.mutation.SetMediaType(v) + return _u +} + +// SetNillableMediaType sets the "media_type" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableMediaType(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetMediaType(*v) + } + return _u +} + +// ClearMediaType clears the value of the "media_type" field. 
+func (_u *UsageLogUpdateOne) ClearMediaType() *UsageLogUpdateOne { + _u.mutation.ClearMediaType() + return _u +} + +// SetCacheTTLOverridden sets the "cache_ttl_overridden" field. +func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne { + _u.mutation.SetCacheTTLOverridden(v) + return _u +} + +// SetNillableCacheTTLOverridden sets the "cache_ttl_overridden" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableCacheTTLOverridden(v *bool) *UsageLogUpdateOne { + if v != nil { + _u.SetCacheTTLOverridden(*v) + } + return _u +} + // SetUser sets the "user" edge to the User entity. func (_u *UsageLogUpdateOne) SetUser(v *User) *UsageLogUpdateOne { return _u.SetUserID(v.ID) @@ -1766,6 +1848,11 @@ func (_u *UsageLogUpdateOne) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } + if v, ok := _u.mutation.MediaType(); ok { + if err := usagelog.MediaTypeValidator(v); err != nil { + return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)} + } + } if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { return errors.New(`ent: clearing a required unique edge "UsageLog.user"`) } @@ -1951,6 +2038,15 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err if _u.mutation.ImageSizeCleared() { _spec.ClearField(usagelog.FieldImageSize, field.TypeString) } + if value, ok := _u.mutation.MediaType(); ok { + _spec.SetField(usagelog.FieldMediaType, field.TypeString, value) + } + if _u.mutation.MediaTypeCleared() { + _spec.ClearField(usagelog.FieldMediaType, field.TypeString) + } + if value, ok := _u.mutation.CacheTTLOverridden(); ok { + _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, diff --git a/backend/ent/user.go b/backend/ent/user.go index 
2435aa1b..b3f933f6 100644 --- a/backend/ent/user.go +++ b/backend/ent/user.go @@ -45,6 +45,10 @@ type User struct { TotpEnabled bool `json:"totp_enabled,omitempty"` // TotpEnabledAt holds the value of the "totp_enabled_at" field. TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"` + // SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field. + SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"` + // SoraStorageUsedBytes holds the value of the "sora_storage_used_bytes" field. + SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the UserQuery when eager-loading is set. Edges UserEdges `json:"edges"` @@ -177,7 +181,7 @@ func (*User) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case user.FieldBalance: values[i] = new(sql.NullFloat64) - case user.FieldID, user.FieldConcurrency: + case user.FieldID, user.FieldConcurrency, user.FieldSoraStorageQuotaBytes, user.FieldSoraStorageUsedBytes: values[i] = new(sql.NullInt64) case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted: values[i] = new(sql.NullString) @@ -291,6 +295,18 @@ func (_m *User) assignValues(columns []string, values []any) error { _m.TotpEnabledAt = new(time.Time) *_m.TotpEnabledAt = value.Time } + case user.FieldSoraStorageQuotaBytes: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i]) + } else if value.Valid { + _m.SoraStorageQuotaBytes = value.Int64 + } + case user.FieldSoraStorageUsedBytes: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field sora_storage_used_bytes", values[i]) + } else if value.Valid { + _m.SoraStorageUsedBytes = value.Int64 + } default: 
_m.selectValues.Set(columns[i], values[i]) } @@ -424,6 +440,12 @@ func (_m *User) String() string { builder.WriteString("totp_enabled_at=") builder.WriteString(v.Format(time.ANSIC)) } + builder.WriteString(", ") + builder.WriteString("sora_storage_quota_bytes=") + builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes)) + builder.WriteString(", ") + builder.WriteString("sora_storage_used_bytes=") + builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageUsedBytes)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go index ae9418ff..155b9160 100644 --- a/backend/ent/user/user.go +++ b/backend/ent/user/user.go @@ -43,6 +43,10 @@ const ( FieldTotpEnabled = "totp_enabled" // FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database. FieldTotpEnabledAt = "totp_enabled_at" + // FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database. + FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes" + // FieldSoraStorageUsedBytes holds the string denoting the sora_storage_used_bytes field in the database. + FieldSoraStorageUsedBytes = "sora_storage_used_bytes" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -152,6 +156,8 @@ var Columns = []string{ FieldTotpSecretEncrypted, FieldTotpEnabled, FieldTotpEnabledAt, + FieldSoraStorageQuotaBytes, + FieldSoraStorageUsedBytes, } var ( @@ -208,6 +214,10 @@ var ( DefaultNotes string // DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field. DefaultTotpEnabled bool + // DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field. + DefaultSoraStorageQuotaBytes int64 + // DefaultSoraStorageUsedBytes holds the default value on creation for the "sora_storage_used_bytes" field. 
+ DefaultSoraStorageUsedBytes int64 ) // OrderOption defines the ordering options for the User queries. @@ -288,6 +298,16 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc() } +// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field. +func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc() +} + +// BySoraStorageUsedBytes orders the results by the sora_storage_used_bytes field. +func BySoraStorageUsedBytes(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraStorageUsedBytes, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go index 1de61037..e26afcf3 100644 --- a/backend/ent/user/where.go +++ b/backend/ent/user/where.go @@ -125,6 +125,16 @@ func TotpEnabledAt(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v)) } +// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ. +func SoraStorageQuotaBytes(v int64) predicate.User { + return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageUsedBytes applies equality check predicate on the "sora_storage_used_bytes" field. It's identical to SoraStorageUsedBytesEQ. +func SoraStorageUsedBytes(v int64) predicate.User { + return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. 
func CreatedAtEQ(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldCreatedAt, v)) @@ -860,6 +870,86 @@ func TotpEnabledAtNotNil() predicate.User { return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt)) } +// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesEQ(v int64) predicate.User { + return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesNEQ(v int64) predicate.User { + return predicate.User(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesIn(vs ...int64) predicate.User { + return predicate.User(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...)) +} + +// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.User { + return predicate.User(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...)) +} + +// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesGT(v int64) predicate.User { + return predicate.User(sql.FieldGT(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesGTE(v int64) predicate.User { + return predicate.User(sql.FieldGTE(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesLT(v int64) predicate.User { + return predicate.User(sql.FieldLT(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field. 
+func SoraStorageQuotaBytesLTE(v int64) predicate.User { + return predicate.User(sql.FieldLTE(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageUsedBytesEQ applies the EQ predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesEQ(v int64) predicate.User { + return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v)) +} + +// SoraStorageUsedBytesNEQ applies the NEQ predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesNEQ(v int64) predicate.User { + return predicate.User(sql.FieldNEQ(FieldSoraStorageUsedBytes, v)) +} + +// SoraStorageUsedBytesIn applies the In predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesIn(vs ...int64) predicate.User { + return predicate.User(sql.FieldIn(FieldSoraStorageUsedBytes, vs...)) +} + +// SoraStorageUsedBytesNotIn applies the NotIn predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesNotIn(vs ...int64) predicate.User { + return predicate.User(sql.FieldNotIn(FieldSoraStorageUsedBytes, vs...)) +} + +// SoraStorageUsedBytesGT applies the GT predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesGT(v int64) predicate.User { + return predicate.User(sql.FieldGT(FieldSoraStorageUsedBytes, v)) +} + +// SoraStorageUsedBytesGTE applies the GTE predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesGTE(v int64) predicate.User { + return predicate.User(sql.FieldGTE(FieldSoraStorageUsedBytes, v)) +} + +// SoraStorageUsedBytesLT applies the LT predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesLT(v int64) predicate.User { + return predicate.User(sql.FieldLT(FieldSoraStorageUsedBytes, v)) +} + +// SoraStorageUsedBytesLTE applies the LTE predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesLTE(v int64) predicate.User { + return predicate.User(sql.FieldLTE(FieldSoraStorageUsedBytes, v)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. 
func HasAPIKeys() predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go index f862a580..df0c6bcc 100644 --- a/backend/ent/user_create.go +++ b/backend/ent/user_create.go @@ -210,6 +210,34 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate { return _c } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (_c *UserCreate) SetSoraStorageQuotaBytes(v int64) *UserCreate { + _c.mutation.SetSoraStorageQuotaBytes(v) + return _c +} + +// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. +func (_c *UserCreate) SetNillableSoraStorageQuotaBytes(v *int64) *UserCreate { + if v != nil { + _c.SetSoraStorageQuotaBytes(*v) + } + return _c +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (_c *UserCreate) SetSoraStorageUsedBytes(v int64) *UserCreate { + _c.mutation.SetSoraStorageUsedBytes(v) + return _c +} + +// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil. +func (_c *UserCreate) SetNillableSoraStorageUsedBytes(v *int64) *UserCreate { + if v != nil { + _c.SetSoraStorageUsedBytes(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate { _c.mutation.AddAPIKeyIDs(ids...) 
@@ -424,6 +452,14 @@ func (_c *UserCreate) defaults() error { v := user.DefaultTotpEnabled _c.mutation.SetTotpEnabled(v) } + if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok { + v := user.DefaultSoraStorageQuotaBytes + _c.mutation.SetSoraStorageQuotaBytes(v) + } + if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok { + v := user.DefaultSoraStorageUsedBytes + _c.mutation.SetSoraStorageUsedBytes(v) + } return nil } @@ -487,6 +523,12 @@ func (_c *UserCreate) check() error { if _, ok := _c.mutation.TotpEnabled(); !ok { return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)} } + if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok { + return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "User.sora_storage_quota_bytes"`)} + } + if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok { + return &ValidationError{Name: "sora_storage_used_bytes", err: errors.New(`ent: missing required field "User.sora_storage_used_bytes"`)} + } return nil } @@ -570,6 +612,14 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { _spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value) _node.TotpEnabledAt = &value } + if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok { + _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + _node.SoraStorageQuotaBytes = value + } + if value, ok := _c.mutation.SoraStorageUsedBytes(); ok { + _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) + _node.SoraStorageUsedBytes = value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -956,6 +1006,42 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert { return u } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. 
+func (u *UserUpsert) SetSoraStorageQuotaBytes(v int64) *UserUpsert { + u.Set(user.FieldSoraStorageQuotaBytes, v) + return u +} + +// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. +func (u *UserUpsert) UpdateSoraStorageQuotaBytes() *UserUpsert { + u.SetExcluded(user.FieldSoraStorageQuotaBytes) + return u +} + +// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. +func (u *UserUpsert) AddSoraStorageQuotaBytes(v int64) *UserUpsert { + u.Add(user.FieldSoraStorageQuotaBytes, v) + return u +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (u *UserUpsert) SetSoraStorageUsedBytes(v int64) *UserUpsert { + u.Set(user.FieldSoraStorageUsedBytes, v) + return u +} + +// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create. +func (u *UserUpsert) UpdateSoraStorageUsedBytes() *UserUpsert { + u.SetExcluded(user.FieldSoraStorageUsedBytes) + return u +} + +// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field. +func (u *UserUpsert) AddSoraStorageUsedBytes(v int64) *UserUpsert { + u.Add(user.FieldSoraStorageUsedBytes, v) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -1218,6 +1304,48 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne { }) } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (u *UserUpsertOne) SetSoraStorageQuotaBytes(v int64) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetSoraStorageQuotaBytes(v) + }) +} + +// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. 
+func (u *UserUpsertOne) AddSoraStorageQuotaBytes(v int64) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.AddSoraStorageQuotaBytes(v) + }) +} + +// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateSoraStorageQuotaBytes() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateSoraStorageQuotaBytes() + }) +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (u *UserUpsertOne) SetSoraStorageUsedBytes(v int64) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetSoraStorageUsedBytes(v) + }) +} + +// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field. +func (u *UserUpsertOne) AddSoraStorageUsedBytes(v int64) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.AddSoraStorageUsedBytes(v) + }) +} + +// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateSoraStorageUsedBytes() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateSoraStorageUsedBytes() + }) +} + // Exec executes the query. func (u *UserUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -1646,6 +1774,48 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk { }) } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (u *UserUpsertBulk) SetSoraStorageQuotaBytes(v int64) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetSoraStorageQuotaBytes(v) + }) +} + +// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. +func (u *UserUpsertBulk) AddSoraStorageQuotaBytes(v int64) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.AddSoraStorageQuotaBytes(v) + }) +} + +// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. 
+func (u *UserUpsertBulk) UpdateSoraStorageQuotaBytes() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateSoraStorageQuotaBytes() + }) +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (u *UserUpsertBulk) SetSoraStorageUsedBytes(v int64) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetSoraStorageUsedBytes(v) + }) +} + +// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field. +func (u *UserUpsertBulk) AddSoraStorageUsedBytes(v int64) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.AddSoraStorageUsedBytes(v) + }) +} + +// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateSoraStorageUsedBytes() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateSoraStorageUsedBytes() + }) +} + // Exec executes the query. func (u *UserUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go index 80222c92..f71f0cad 100644 --- a/backend/ent/user_update.go +++ b/backend/ent/user_update.go @@ -242,6 +242,48 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate { return _u } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (_u *UserUpdate) SetSoraStorageQuotaBytes(v int64) *UserUpdate { + _u.mutation.ResetSoraStorageQuotaBytes() + _u.mutation.SetSoraStorageQuotaBytes(v) + return _u +} + +// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. +func (_u *UserUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdate { + if v != nil { + _u.SetSoraStorageQuotaBytes(*v) + } + return _u +} + +// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field. 
+func (_u *UserUpdate) AddSoraStorageQuotaBytes(v int64) *UserUpdate { + _u.mutation.AddSoraStorageQuotaBytes(v) + return _u +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (_u *UserUpdate) SetSoraStorageUsedBytes(v int64) *UserUpdate { + _u.mutation.ResetSoraStorageUsedBytes() + _u.mutation.SetSoraStorageUsedBytes(v) + return _u +} + +// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil. +func (_u *UserUpdate) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdate { + if v != nil { + _u.SetSoraStorageUsedBytes(*v) + } + return _u +} + +// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field. +func (_u *UserUpdate) AddSoraStorageUsedBytes(v int64) *UserUpdate { + _u.mutation.AddSoraStorageUsedBytes(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -709,6 +751,18 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.TotpEnabledAtCleared() { _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) } + if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok { + _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok { + _spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.SoraStorageUsedBytes(); ok { + _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok { + _spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1352,6 +1406,48 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne { return _u } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" 
field. +func (_u *UserUpdateOne) SetSoraStorageQuotaBytes(v int64) *UserUpdateOne { + _u.mutation.ResetSoraStorageQuotaBytes() + _u.mutation.SetSoraStorageQuotaBytes(v) + return _u +} + +// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdateOne { + if v != nil { + _u.SetSoraStorageQuotaBytes(*v) + } + return _u +} + +// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field. +func (_u *UserUpdateOne) AddSoraStorageQuotaBytes(v int64) *UserUpdateOne { + _u.mutation.AddSoraStorageQuotaBytes(v) + return _u +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (_u *UserUpdateOne) SetSoraStorageUsedBytes(v int64) *UserUpdateOne { + _u.mutation.ResetSoraStorageUsedBytes() + _u.mutation.SetSoraStorageUsedBytes(v) + return _u +} + +// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdateOne { + if v != nil { + _u.SetSoraStorageUsedBytes(*v) + } + return _u +} + +// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field. +func (_u *UserUpdateOne) AddSoraStorageUsedBytes(v int64) *UserUpdateOne { + _u.mutation.AddSoraStorageUsedBytes(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) 
@@ -1849,6 +1945,18 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { if _u.mutation.TotpEnabledAtCleared() { _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) } + if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok { + _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok { + _spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.SoraStorageUsedBytes(); ok { + _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok { + _spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/go.mod b/backend/go.mod index 08d54b91..ab76258a 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -5,6 +5,13 @@ go 1.25.7 require ( entgo.io/ent v0.14.5 github.com/DATA-DOG/go-sqlmock v1.5.2 + github.com/DouDOU-start/go-sora2api v1.1.0 + github.com/alitto/pond/v2 v2.6.2 + github.com/aws/aws-sdk-go-v2/config v1.32.10 + github.com/aws/aws-sdk-go-v2/credentials v1.19.10 + github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2 + github.com/cespare/xxhash/v2 v2.3.0 + github.com/coder/websocket v1.8.14 github.com/dgraph-io/ristretto v0.2.0 github.com/gin-gonic/gin v1.9.1 github.com/golang-jwt/jwt/v5 v5.2.2 @@ -13,9 +20,10 @@ require ( github.com/gorilla/websocket v1.5.3 github.com/imroc/req/v3 v3.57.0 github.com/lib/pq v1.10.9 + github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pquerna/otp v1.5.0 github.com/redis/go-redis/v9 v9.17.2 - github.com/refraction-networking/utls v1.8.1 + github.com/refraction-networking/utls v1.8.2 github.com/robfig/cron/v3 v3.0.1 github.com/shirou/gopsutil/v4 v4.25.6 github.com/spf13/viper v1.18.2 @@ -25,10 +33,14 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson 
v1.2.5 github.com/zeromicro/go-zero v1.9.4 - golang.org/x/crypto v0.47.0 + go.uber.org/zap v1.24.0 + golang.org/x/crypto v0.48.0 golang.org/x/net v0.49.0 golang.org/x/sync v0.19.0 - golang.org/x/term v0.39.0 + golang.org/x/term v0.40.0 + google.golang.org/grpc v1.75.1 + google.golang.org/protobuf v1.36.10 + gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v3 v3.0.1 modernc.org/sqlite v1.44.3 ) @@ -41,11 +53,33 @@ require ( github.com/agext/levenshtein v1.2.3 // indirect github.com/andybalholm/brotli v1.2.0 // indirect github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect + github.com/aws/aws-sdk-go-v2 v1.41.2 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.18 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.10 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.18 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.18 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.6 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect + github.com/aws/smithy-go v1.24.1 // indirect + github.com/bdandy/go-errors v1.2.2 // indirect + github.com/bdandy/go-socks4 v1.2.3 // indirect github.com/bmatcuk/doublestar v1.3.4 // indirect + github.com/bogdanfinn/fhttp v0.6.8 // indirect + github.com/bogdanfinn/quic-go-utls v1.0.9-utls // indirect + github.com/bogdanfinn/tls-client v1.14.0 // indirect 
+ github.com/bogdanfinn/utls v1.7.7-barnius // indirect + github.com/bogdanfinn/websocket v1.5.5-barnius // indirect github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/bytedance/sonic v1.9.1 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect - github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/containerd/errdefs v1.0.0 // indirect github.com/containerd/errdefs/pkg v0.3.0 // indirect @@ -75,6 +109,7 @@ require ( github.com/goccy/go-json v0.10.2 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-querystring v1.1.0 // indirect + github.com/google/subcommands v1.2.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl/v2 v2.18.1 // indirect @@ -103,7 +138,6 @@ require ( github.com/ncruces/go-strftime v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect - github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect @@ -120,6 +154,7 @@ require ( github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 // indirect github.com/testcontainers/testcontainers-go v0.40.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect @@ -134,18 +169,17 @@ require ( go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect go.opentelemetry.io/otel v1.37.0 // indirect go.opentelemetry.io/otel/metric v1.37.0 // indirect - go.opentelemetry.io/otel/sdk v1.37.0 // indirect go.opentelemetry.io/otel/trace 
v1.37.0 // indirect go.uber.org/atomic v1.10.0 // indirect go.uber.org/automaxprocs v1.6.0 // indirect go.uber.org/multierr v1.9.0 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect - golang.org/x/mod v0.31.0 // indirect - golang.org/x/sys v0.40.0 // indirect - golang.org/x/text v0.33.0 // indirect - google.golang.org/grpc v1.75.1 // indirect - google.golang.org/protobuf v1.36.10 // indirect + golang.org/x/mod v0.32.0 // indirect + golang.org/x/sys v0.41.0 // indirect + golang.org/x/text v0.34.0 // indirect + golang.org/x/tools v0.41.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect gopkg.in/ini.v1 v1.67.0 // indirect modernc.org/libc v1.67.6 // indirect modernc.org/mathutil v1.7.1 // indirect diff --git a/backend/go.sum b/backend/go.sum index 71e8f504..32e389a7 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -10,16 +10,74 @@ github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOEl github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= +github.com/DouDOU-start/go-sora2api v1.1.0 h1:PxWiukK77StiHxEngOFwT1rKUn9oTAJJTl07wQUXwiU= +github.com/DouDOU-start/go-sora2api v1.1.0/go.mod h1:dcwpethoKfAsMWskDD9iGgc/3yox2tkthPLSMVGnhkE= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo= github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558= +github.com/alitto/pond/v2 v2.6.2 h1:Sphe40g0ILeM1pA2c2K+Th0DGU+pt0A/Kprr+WB24Pw= +github.com/alitto/pond/v2 
v2.6.2/go.mod h1:xkjYEgQ05RSpWdfSd1nM3OVv7TBhLdy7rMp3+2Nq+yE= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY= github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4= +github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls= +github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c= +github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI= +github.com/aws/aws-sdk-go-v2/config v1.32.10/go.mod h1:2rUIOnA2JaiqYmSKYmRJlcMWy6qTj1vuRFscppSBMcw= +github.com/aws/aws-sdk-go-v2/credentials v1.19.10 h1:EEhmEUFCE1Yhl7vDhNOI5OCL/iKMdkkYFTRpZXNw7m8= +github.com/aws/aws-sdk-go-v2/credentials v1.19.10/go.mod h1:RnnlFCAlxQCkN2Q379B67USkBMu1PipEEiibzYN5UTE= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 h1:Ii4s+Sq3yDfaMLpjrJsqD6SmG/Wq/P5L/hw2qa78UAY= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18/go.mod h1:6x81qnY++ovptLE6nWQeWrpXxbnlIex+4H4eYYGcqfc= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 h1:F43zk1vemYIqPAwhjTjYIz0irU2EY7sOb/F5eJ3HuyM= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18/go.mod h1:w1jdlZXrGKaJcNoL+Nnrj+k5wlpGXqnNrKoP22HvAug= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 h1:xCeWVjj0ki0l3nruoyP2slHsGArMxeiiaoPN5QZH6YQ= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18/go.mod h1:r/eLGuGCBw6l36ZRWiw6PaZwPXb6YOj+i/7MizNl5/k= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 
h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.18 h1:eZioDaZGJ0tMM4gzmkNIO2aAoQd+je7Ug7TkvAzlmkU= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.18/go.mod h1:CCXwUKAJdoWr6/NcxZ+zsiPr6oH/Q5aTooRGYieAyj4= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5 h1:CeY9LUdur+Dxoeldqoun6y4WtJ3RQtzk0JMP2gfUay0= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5/go.mod h1:AZLZf2fMaahW5s/wMRciu1sYbdsikT/UHwbUjOdEVTc= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.10 h1:fJvQ5mIBVfKtiyx0AHY6HeWcRX5LGANLpq8SVR+Uazs= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.10/go.mod h1:Kzm5e6OmNH8VMkgK9t+ry5jEih4Y8whqs+1hrkxim1I= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.18 h1:LTRCYFlnnKFlKsyIQxKhJuDuA3ZkrDQMRYm6rXiHlLY= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.18/go.mod h1:XhwkgGG6bHSd00nO/mexWTcTjgd6PjuvWQMqSn2UaEk= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.18 h1:/A/xDuZAVD2BpsS2fftFRo/NoEKQJ8YTnJDEHBy2Gtg= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.18/go.mod h1:hWe9b4f+djUQGmyiGEeOnZv69dtMSgpDRIvNMvuvzvY= +github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2 h1:M1A9AjcFwlxTLuf0Faj88L8Iqw0n/AJHjpZTQzMMsSc= +github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2/go.mod h1:KsdTV6Q9WKUZm2mNJnUFmIoXfZux91M3sr/a4REX8e0= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.6 h1:MzORe+J94I+hYu2a6XmV5yC9huoTv8NRcCrUNedDypQ= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.6/go.mod h1:hXzcHLARD7GeWnifd8j9RWqtfIgxj4/cAtIVIK7hg8g= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 h1:7oGD8KPfBOJGXiCoRKrrrQkbvCp8N++u36hrLMPey6o= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.11/go.mod h1:0DO9B5EUJQlIDif+XJRWCljZRKsAFKh3gpFz7UnDtOo= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 
h1:edCcNp9eGIUDUCrzoCu1jWAXLGFIizeqkdkKgRlJwWc= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15/go.mod h1:lyRQKED9xWfgkYC/wmmYfv7iVIM68Z5OQ88ZdcV1QbU= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb83BbyggcUBVksN7c= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs= +github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0= +github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= +github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q= +github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM= +github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic= +github.com/bdandy/go-socks4 v1.2.3/go.mod h1:98kiVFgpdogR8aIGLWLvjDVZ8XcKPsSI/ypGrO+bqHI= +github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= +github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0= github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE= +github.com/bogdanfinn/fhttp v0.6.8 h1:LiQyHOY3i0QoxxNB7nq27/nGNNbtPj0fuBPozhR7Ws4= +github.com/bogdanfinn/fhttp v0.6.8/go.mod h1:A+EKDzMx2hb4IUbMx4TlkoHnaJEiLl8r/1Ss1Y+5e5M= +github.com/bogdanfinn/quic-go-utls v1.0.9-utls h1:tV6eDEiRbRCcepALSzxR94JUVD3N3ACIiRLgyc2Ep8s= +github.com/bogdanfinn/quic-go-utls v1.0.9-utls/go.mod h1:aHph9B9H9yPOt5xnhWKSOum27DJAqpiHzwX+gjvaXcg= +github.com/bogdanfinn/tls-client v1.14.0 h1:vyk7Cn4BIvLAGVuMfb0tP22OqogfO1lYamquQNEZU1A= +github.com/bogdanfinn/tls-client v1.14.0/go.mod h1:LsU6mXVn8MOFDwTkyRfI7V1BZM1p0wf2ZfZsICW/1fM= +github.com/bogdanfinn/utls v1.7.7-barnius h1:OuJ497cc7F3yKNVHRsYPQdGggmk5x6+V5ZlrCR7fOLU= +github.com/bogdanfinn/utls v1.7.7-barnius/go.mod h1:aAK1VZQlpKZClF1WEQeq6kyclbkPq4hz6xTbB5xSlmg= 
+github.com/bogdanfinn/websocket v1.5.5-barnius h1:bY+qnxpai1qe7Jmjx+Sds/cmOSpuuLoR8x61rWltjOI= +github.com/bogdanfinn/websocket v1.5.5-barnius/go.mod h1:gvvEw6pTKHb7yOiFvIfAFTStQWyrm25BMVCTj5wRSsI= github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= @@ -36,6 +94,12 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= +github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= +github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U= +github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= +github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= +github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M= github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE= @@ -107,6 +171,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang-jwt/jwt/v5 
v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -116,6 +182,8 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= +github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= @@ -135,8 +203,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4= github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y= github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI= github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00= -github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= -github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 
h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -172,8 +238,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= -github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= +github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= @@ -234,12 +300,10 @@ github.com/quic-go/quic-go v0.57.1 h1:25KAAR9QR8KZrCZRThWMKVAwGoiHIrNbT72ULHTuI1 github.com/quic-go/quic-go v0.57.1/go.mod h1:ly4QBAjHA2VhdnxhojRsCUOeJwKYg+taDlos92xb1+s= github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI= github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= -github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo= -github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= +github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo= +github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= 
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= -github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= @@ -285,6 +349,8 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 h1:YqAladjX7xpA6BM04leXMWAEjS0mTZ5kUU9KRBriQJc= +github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5/go.mod h1:2JjD2zLQYH5HO74y5+aE3remJQvl6q4Sn6aWA2wD1Ng= github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU= github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY= github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 h1:s2bIayFXlbDFexo96y+htn7FzuhpXLYJNnIuglNKqOk= @@ -332,6 +398,8 @@ go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/Wgbsd go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI= go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg= +go.opentelemetry.io/otel/sdk/metric v1.37.0 
h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc= +go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps= go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0= @@ -340,25 +408,32 @@ go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60= +go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= -golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= golang.org/x/exp 
v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= -golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= -golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= +golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c= +golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= +golang.org/x/net v0.0.0-20211104170005-ce137452f963/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -366,17 +441,22 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= -golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= -golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= -golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= -golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= +golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= -golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= -golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= +golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod 
h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ= google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU= google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk= @@ -391,6 +471,8 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= +gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 91437ba8..54be38a1 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -5,7 +5,7 @@ import ( "crypto/rand" "encoding/hex" "fmt" - "log" + "log/slog" "net/url" "os" "strings" @@ -19,10 +19,25 @@ const ( RunModeSimple = "simple" ) +// 使用量记录队列溢出策略 +const ( + UsageRecordOverflowPolicyDrop = "drop" + UsageRecordOverflowPolicySample = "sample" + UsageRecordOverflowPolicySync = "sync" +) + // DefaultCSPPolicy is the default Content-Security-Policy with nonce support // __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ 
https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'" +// UMQ(用户消息队列)模式常量 +const ( + // UMQModeSerialize: 账号级串行锁 + RPM 自适应延迟 + UMQModeSerialize = "serialize" + // UMQModeThrottle: 仅 RPM 自适应前置延迟,不阻塞并发 + UMQModeThrottle = "throttle" +) + // 连接池隔离策略常量 // 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗 const ( @@ -38,31 +53,68 @@ const ( ) type Config struct { - Server ServerConfig `mapstructure:"server"` - CORS CORSConfig `mapstructure:"cors"` - Security SecurityConfig `mapstructure:"security"` - Billing BillingConfig `mapstructure:"billing"` - Turnstile TurnstileConfig `mapstructure:"turnstile"` - Database DatabaseConfig `mapstructure:"database"` - Redis RedisConfig `mapstructure:"redis"` - Ops OpsConfig `mapstructure:"ops"` - JWT JWTConfig `mapstructure:"jwt"` - Totp TotpConfig `mapstructure:"totp"` - LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` - Default DefaultConfig `mapstructure:"default"` - RateLimit RateLimitConfig `mapstructure:"rate_limit"` - Pricing PricingConfig `mapstructure:"pricing"` - Gateway GatewayConfig `mapstructure:"gateway"` - APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` - Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"` - DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"` - UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` - Concurrency ConcurrencyConfig `mapstructure:"concurrency"` - TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` - RunMode string `mapstructure:"run_mode" yaml:"run_mode"` - Timezone string `mapstructure:"timezone"` // e.g. 
"Asia/Shanghai", "UTC" - Gemini GeminiConfig `mapstructure:"gemini"` - Update UpdateConfig `mapstructure:"update"` + Server ServerConfig `mapstructure:"server"` + Log LogConfig `mapstructure:"log"` + CORS CORSConfig `mapstructure:"cors"` + Security SecurityConfig `mapstructure:"security"` + Billing BillingConfig `mapstructure:"billing"` + Turnstile TurnstileConfig `mapstructure:"turnstile"` + Database DatabaseConfig `mapstructure:"database"` + Redis RedisConfig `mapstructure:"redis"` + Ops OpsConfig `mapstructure:"ops"` + JWT JWTConfig `mapstructure:"jwt"` + Totp TotpConfig `mapstructure:"totp"` + LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` + Default DefaultConfig `mapstructure:"default"` + RateLimit RateLimitConfig `mapstructure:"rate_limit"` + Pricing PricingConfig `mapstructure:"pricing"` + Gateway GatewayConfig `mapstructure:"gateway"` + APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` + SubscriptionCache SubscriptionCacheConfig `mapstructure:"subscription_cache"` + SubscriptionMaintenance SubscriptionMaintenanceConfig `mapstructure:"subscription_maintenance"` + Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"` + DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"` + UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` + Concurrency ConcurrencyConfig `mapstructure:"concurrency"` + TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` + Sora SoraConfig `mapstructure:"sora"` + RunMode string `mapstructure:"run_mode" yaml:"run_mode"` + Timezone string `mapstructure:"timezone"` // e.g. 
"Asia/Shanghai", "UTC" + Gemini GeminiConfig `mapstructure:"gemini"` + Update UpdateConfig `mapstructure:"update"` + Idempotency IdempotencyConfig `mapstructure:"idempotency"` +} + +type LogConfig struct { + Level string `mapstructure:"level"` + Format string `mapstructure:"format"` + ServiceName string `mapstructure:"service_name"` + Environment string `mapstructure:"env"` + Caller bool `mapstructure:"caller"` + StacktraceLevel string `mapstructure:"stacktrace_level"` + Output LogOutputConfig `mapstructure:"output"` + Rotation LogRotationConfig `mapstructure:"rotation"` + Sampling LogSamplingConfig `mapstructure:"sampling"` +} + +type LogOutputConfig struct { + ToStdout bool `mapstructure:"to_stdout"` + ToFile bool `mapstructure:"to_file"` + FilePath string `mapstructure:"file_path"` +} + +type LogRotationConfig struct { + MaxSizeMB int `mapstructure:"max_size_mb"` + MaxBackups int `mapstructure:"max_backups"` + MaxAgeDays int `mapstructure:"max_age_days"` + Compress bool `mapstructure:"compress"` + LocalTime bool `mapstructure:"local_time"` +} + +type LogSamplingConfig struct { + Enabled bool `mapstructure:"enabled"` + Initial int `mapstructure:"initial"` + Thereafter int `mapstructure:"thereafter"` } type GeminiConfig struct { @@ -94,6 +146,25 @@ type UpdateConfig struct { ProxyURL string `mapstructure:"proxy_url"` } +type IdempotencyConfig struct { + // ObserveOnly 为 true 时处于观察期:未携带 Idempotency-Key 的请求继续放行。 + ObserveOnly bool `mapstructure:"observe_only"` + // DefaultTTLSeconds 关键写接口的幂等记录默认 TTL(秒)。 + DefaultTTLSeconds int `mapstructure:"default_ttl_seconds"` + // SystemOperationTTLSeconds 系统操作接口的幂等记录 TTL(秒)。 + SystemOperationTTLSeconds int `mapstructure:"system_operation_ttl_seconds"` + // ProcessingTimeoutSeconds processing 状态锁超时(秒)。 + ProcessingTimeoutSeconds int `mapstructure:"processing_timeout_seconds"` + // FailedRetryBackoffSeconds 失败退避窗口(秒)。 + FailedRetryBackoffSeconds int `mapstructure:"failed_retry_backoff_seconds"` + // MaxStoredResponseLen 
持久化响应体最大长度(字节)。 + MaxStoredResponseLen int `mapstructure:"max_stored_response_len"` + // CleanupIntervalSeconds 过期记录清理周期(秒)。 + CleanupIntervalSeconds int `mapstructure:"cleanup_interval_seconds"` + // CleanupBatchSize 每次清理的最大记录数。 + CleanupBatchSize int `mapstructure:"cleanup_batch_size"` +} + type LinuxDoConnectConfig struct { Enabled bool `mapstructure:"enabled"` ClientID string `mapstructure:"client_id"` @@ -126,6 +197,8 @@ type TokenRefreshConfig struct { MaxRetries int `mapstructure:"max_retries"` // 重试退避基础时间(秒) RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"` + // 是否允许 OpenAI 刷新器同步覆盖关联的 Sora 账号 token(默认关闭) + SyncLinkedSoraAccounts bool `mapstructure:"sync_linked_sora_accounts"` } type PricingConfig struct { @@ -147,6 +220,7 @@ type ServerConfig struct { Host string `mapstructure:"host"` Port int `mapstructure:"port"` Mode string `mapstructure:"mode"` // debug/release + FrontendURL string `mapstructure:"frontend_url"` // 前端基础 URL,用于生成邮件中的外部链接 ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒) IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒) TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP) @@ -173,6 +247,7 @@ type SecurityConfig struct { URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"` ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"` CSP CSPConfig `mapstructure:"csp"` + ProxyFallback ProxyFallbackConfig `mapstructure:"proxy_fallback"` ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"` } @@ -197,6 +272,17 @@ type CSPConfig struct { Policy string `mapstructure:"policy"` } +type ProxyFallbackConfig struct { + // AllowDirectOnError 当辅助服务的代理初始化失败时是否允许回退直连。 + // 仅影响以下非 AI 账号连接的辅助服务: + // - GitHub Release 更新检查 + // - 定价数据拉取 + // 不影响 AI 账号网关连接(Claude/OpenAI/Gemini/Antigravity), + // 这些关键路径的代理失败始终返回错误,不会回退直连。 + // 默认 false:避免因代理配置错误导致服务器真实 IP 泄露。 + AllowDirectOnError bool `mapstructure:"allow_direct_on_error"` +} + type ProxyProbeConfig struct { 
InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"` // 已禁用:禁止跳过 TLS 证书验证 } @@ -217,6 +303,59 @@ type ConcurrencyConfig struct { PingInterval int `mapstructure:"ping_interval"` } +// SoraConfig 直连 Sora 配置 +type SoraConfig struct { + Client SoraClientConfig `mapstructure:"client"` + Storage SoraStorageConfig `mapstructure:"storage"` +} + +// SoraClientConfig 直连 Sora 客户端配置 +type SoraClientConfig struct { + BaseURL string `mapstructure:"base_url"` + TimeoutSeconds int `mapstructure:"timeout_seconds"` + MaxRetries int `mapstructure:"max_retries"` + CloudflareChallengeCooldownSeconds int `mapstructure:"cloudflare_challenge_cooldown_seconds"` + PollIntervalSeconds int `mapstructure:"poll_interval_seconds"` + MaxPollAttempts int `mapstructure:"max_poll_attempts"` + RecentTaskLimit int `mapstructure:"recent_task_limit"` + RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"` + Debug bool `mapstructure:"debug"` + UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"` + Headers map[string]string `mapstructure:"headers"` + UserAgent string `mapstructure:"user_agent"` + DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"` + CurlCFFISidecar SoraCurlCFFISidecarConfig `mapstructure:"curl_cffi_sidecar"` +} + +// SoraCurlCFFISidecarConfig Sora 专用 curl_cffi sidecar 配置 +type SoraCurlCFFISidecarConfig struct { + Enabled bool `mapstructure:"enabled"` + BaseURL string `mapstructure:"base_url"` + Impersonate string `mapstructure:"impersonate"` + TimeoutSeconds int `mapstructure:"timeout_seconds"` + SessionReuseEnabled bool `mapstructure:"session_reuse_enabled"` + SessionTTLSeconds int `mapstructure:"session_ttl_seconds"` +} + +// SoraStorageConfig 媒体存储配置 +type SoraStorageConfig struct { + Type string `mapstructure:"type"` + LocalPath string `mapstructure:"local_path"` + FallbackToUpstream bool `mapstructure:"fallback_to_upstream"` + MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"` + DownloadTimeoutSeconds int 
`mapstructure:"download_timeout_seconds"` + MaxDownloadBytes int64 `mapstructure:"max_download_bytes"` + Debug bool `mapstructure:"debug"` + Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"` +} + +// SoraStorageCleanupConfig 媒体清理配置 +type SoraStorageCleanupConfig struct { + Enabled bool `mapstructure:"enabled"` + Schedule string `mapstructure:"schedule"` + RetentionDays int `mapstructure:"retention_days"` +} + // GatewayConfig API网关相关配置 type GatewayConfig struct { // 等待上游响应头的超时时间(秒),0表示无超时 @@ -224,8 +363,22 @@ type GatewayConfig struct { ResponseHeaderTimeout int `mapstructure:"response_header_timeout"` // 请求体最大字节数,用于网关请求体大小限制 MaxBodySize int64 `mapstructure:"max_body_size"` + // 非流式上游响应体读取上限(字节),用于防止无界读取导致内存放大 + UpstreamResponseReadMaxBytes int64 `mapstructure:"upstream_response_read_max_bytes"` + // 代理探测响应体读取上限(字节) + ProxyProbeResponseReadMaxBytes int64 `mapstructure:"proxy_probe_response_read_max_bytes"` + // Gemini 上游响应头调试日志开关(默认关闭,避免高频日志开销) + GeminiDebugResponseHeaders bool `mapstructure:"gemini_debug_response_headers"` // ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy) ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"` + // ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。 + // 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。 + ForceCodexCLI bool `mapstructure:"force_codex_cli"` + // OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头 + // 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。 + OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"` + // OpenAIWS: OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP) + OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"` // HTTP 上游连接池配置(性能优化:支持高并发场景调优) // MaxIdleConns: 所有主机的最大空闲连接总数 @@ -271,6 +424,24 @@ type GatewayConfig struct { // 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义) FailoverOn400 bool `mapstructure:"failover_on_400"` + // Sora 专用配置 + // SoraMaxBodySize: Sora 请求体最大字节数(0 表示使用 gateway.max_body_size) + 
SoraMaxBodySize int64 `mapstructure:"sora_max_body_size"` + // SoraStreamTimeoutSeconds: Sora 流式请求总超时(秒,0 表示不限制) + SoraStreamTimeoutSeconds int `mapstructure:"sora_stream_timeout_seconds"` + // SoraRequestTimeoutSeconds: Sora 非流式请求超时(秒,0 表示不限制) + SoraRequestTimeoutSeconds int `mapstructure:"sora_request_timeout_seconds"` + // SoraStreamMode: stream 强制策略(force/error) + SoraStreamMode string `mapstructure:"sora_stream_mode"` + // SoraModelFilters: 模型列表过滤配置 + SoraModelFilters SoraModelFiltersConfig `mapstructure:"sora_model_filters"` + // SoraMediaRequireAPIKey: 是否要求访问 /sora/media 携带 API Key + SoraMediaRequireAPIKey bool `mapstructure:"sora_media_require_api_key"` + // SoraMediaSigningKey: /sora/media 临时签名密钥(空表示禁用签名) + SoraMediaSigningKey string `mapstructure:"sora_media_signing_key"` + // SoraMediaSignedURLTTLSeconds: 临时签名 URL 有效期(秒,<=0 表示禁用) + SoraMediaSignedURLTTLSeconds int `mapstructure:"sora_media_signed_url_ttl_seconds"` + // 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限) MaxAccountSwitches int `mapstructure:"max_account_switches"` // Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格) @@ -284,6 +455,194 @@ type GatewayConfig struct { // TLSFingerprint: TLS指纹伪装配置 TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"` + + // UsageRecord: 使用量记录异步队列配置(有界队列 + 固定 worker) + UsageRecord GatewayUsageRecordConfig `mapstructure:"usage_record"` + + // UserGroupRateCacheTTLSeconds: 用户分组倍率热路径缓存 TTL(秒) + UserGroupRateCacheTTLSeconds int `mapstructure:"user_group_rate_cache_ttl_seconds"` + // ModelsListCacheTTLSeconds: /v1/models 模型列表短缓存 TTL(秒) + ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"` + + // UserMessageQueue: 用户消息串行队列配置 + // 对 role:"user" 的真实用户消息实施账号级串行化 + RPM 自适应延迟 + UserMessageQueue UserMessageQueueConfig `mapstructure:"user_message_queue"` +} + +// UserMessageQueueConfig 用户消息串行队列配置 +// 用于 Anthropic OAuth/SetupToken 账号的用户消息串行化发送 +type UserMessageQueueConfig struct { + // Mode: 模式选择 + // "serialize" = 账号级串行锁 + RPM 自适应延迟 + // "throttle" = 仅 RPM 
自适应前置延迟,不阻塞并发 + // "" = 禁用(默认) + Mode string `mapstructure:"mode"` + // Enabled: 已废弃,仅向后兼容(等同于 mode: "serialize") + Enabled bool `mapstructure:"enabled"` + // LockTTLMs: 串行锁 TTL(毫秒),应大于最长请求时间 + LockTTLMs int `mapstructure:"lock_ttl_ms"` + // WaitTimeoutMs: 等待获取锁的超时时间(毫秒) + WaitTimeoutMs int `mapstructure:"wait_timeout_ms"` + // MinDelayMs: RPM 自适应延迟下限(毫秒) + MinDelayMs int `mapstructure:"min_delay_ms"` + // MaxDelayMs: RPM 自适应延迟上限(毫秒) + MaxDelayMs int `mapstructure:"max_delay_ms"` + // CleanupIntervalSeconds: 孤儿锁清理间隔(秒),0 表示禁用 + CleanupIntervalSeconds int `mapstructure:"cleanup_interval_seconds"` +} + +// WaitTimeout 返回等待超时的 time.Duration +func (c *UserMessageQueueConfig) WaitTimeout() time.Duration { + if c.WaitTimeoutMs <= 0 { + return 30 * time.Second + } + return time.Duration(c.WaitTimeoutMs) * time.Millisecond +} + +// GetEffectiveMode 返回生效的模式 +// 注意:Mode 字段已在 load() 中做过白名单校验和规范化,此处无需重复验证 +func (c *UserMessageQueueConfig) GetEffectiveMode() string { + if c.Mode == UMQModeSerialize || c.Mode == UMQModeThrottle { + return c.Mode + } + if c.Enabled { + return UMQModeSerialize // 向后兼容 + } + return "" +} + +// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。 +// 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。 +type GatewayOpenAIWSConfig struct { + // ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为) + ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"` + // IngressModeDefault: ingress 默认模式(off/shared/dedicated) + IngressModeDefault string `mapstructure:"ingress_mode_default"` + // Enabled: 全局总开关(默认 true) + Enabled bool `mapstructure:"enabled"` + // OAuthEnabled: 是否允许 OpenAI OAuth 账号使用 WS + OAuthEnabled bool `mapstructure:"oauth_enabled"` + // APIKeyEnabled: 是否允许 OpenAI API Key 账号使用 WS + APIKeyEnabled bool `mapstructure:"apikey_enabled"` + // ForceHTTP: 全局强制 HTTP(用于紧急回滚) + ForceHTTP bool `mapstructure:"force_http"` + // AllowStoreRecovery: 允许在 WSv2 下按策略恢复 store=true(默认 false) + AllowStoreRecovery bool `mapstructure:"allow_store_recovery"` + 
// IngressPreviousResponseRecoveryEnabled: ingress 模式收到 previous_response_not_found 时,是否允许自动去掉 previous_response_id 重试一次(默认 true) + IngressPreviousResponseRecoveryEnabled bool `mapstructure:"ingress_previous_response_recovery_enabled"` + // StoreDisabledConnMode: store=false 且无可复用会话连接时的建连策略(strict/adaptive/off) + // - strict: 强制新建连接(隔离优先) + // - adaptive: 仅在高风险失败后强制新建连接(性能与隔离折中) + // - off: 不强制新建连接(复用优先) + StoreDisabledConnMode string `mapstructure:"store_disabled_conn_mode"` + // StoreDisabledForceNewConn: store=false 且无可复用粘连连接时是否强制新建连接(默认 true,保障会话隔离) + // 兼容旧配置;当 StoreDisabledConnMode 为空时才生效。 + StoreDisabledForceNewConn bool `mapstructure:"store_disabled_force_new_conn"` + // PrewarmGenerateEnabled: 是否启用 WSv2 generate=false 预热(默认 false) + PrewarmGenerateEnabled bool `mapstructure:"prewarm_generate_enabled"` + + // Feature 开关:v2 优先于 v1 + ResponsesWebsockets bool `mapstructure:"responses_websockets"` + ResponsesWebsocketsV2 bool `mapstructure:"responses_websockets_v2"` + + // 连接池参数 + MaxConnsPerAccount int `mapstructure:"max_conns_per_account"` + MinIdlePerAccount int `mapstructure:"min_idle_per_account"` + MaxIdlePerAccount int `mapstructure:"max_idle_per_account"` + // DynamicMaxConnsByAccountConcurrencyEnabled: 是否按账号并发动态计算连接池上限 + DynamicMaxConnsByAccountConcurrencyEnabled bool `mapstructure:"dynamic_max_conns_by_account_concurrency_enabled"` + // OAuthMaxConnsFactor: OAuth 账号连接池系数(effective=ceil(concurrency*factor)) + OAuthMaxConnsFactor float64 `mapstructure:"oauth_max_conns_factor"` + // APIKeyMaxConnsFactor: API Key 账号连接池系数(effective=ceil(concurrency*factor)) + APIKeyMaxConnsFactor float64 `mapstructure:"apikey_max_conns_factor"` + DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"` + ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"` + WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"` + PoolTargetUtilization float64 `mapstructure:"pool_target_utilization"` + QueueLimitPerConn int `mapstructure:"queue_limit_per_conn"` + // 
EventFlushBatchSize: WS 流式写出批量 flush 阈值(事件条数) + EventFlushBatchSize int `mapstructure:"event_flush_batch_size"` + // EventFlushIntervalMS: WS 流式写出最大等待时间(毫秒);0 表示仅按 batch 触发 + EventFlushIntervalMS int `mapstructure:"event_flush_interval_ms"` + // PrewarmCooldownMS: 连接池预热触发冷却时间(毫秒) + PrewarmCooldownMS int `mapstructure:"prewarm_cooldown_ms"` + // FallbackCooldownSeconds: WS 回退冷却窗口,避免 WS/HTTP 抖动;0 表示关闭冷却 + FallbackCooldownSeconds int `mapstructure:"fallback_cooldown_seconds"` + // RetryBackoffInitialMS: WS 重试初始退避(毫秒);<=0 表示关闭退避 + RetryBackoffInitialMS int `mapstructure:"retry_backoff_initial_ms"` + // RetryBackoffMaxMS: WS 重试最大退避(毫秒) + RetryBackoffMaxMS int `mapstructure:"retry_backoff_max_ms"` + // RetryJitterRatio: WS 重试退避抖动比例(0-1) + RetryJitterRatio float64 `mapstructure:"retry_jitter_ratio"` + // RetryTotalBudgetMS: WS 单次请求重试总预算(毫秒);0 表示关闭预算限制 + RetryTotalBudgetMS int `mapstructure:"retry_total_budget_ms"` + // PayloadLogSampleRate: payload_schema 日志采样率(0-1) + PayloadLogSampleRate float64 `mapstructure:"payload_log_sample_rate"` + + // 账号调度与粘连参数 + LBTopK int `mapstructure:"lb_top_k"` + // StickySessionTTLSeconds: session_hash -> account_id 粘连 TTL + StickySessionTTLSeconds int `mapstructure:"sticky_session_ttl_seconds"` + // SessionHashReadOldFallback: 会话哈希迁移期是否允许“新 key 未命中时回退读旧 SHA-256 key” + SessionHashReadOldFallback bool `mapstructure:"session_hash_read_old_fallback"` + // SessionHashDualWriteOld: 会话哈希迁移期是否双写旧 SHA-256 key(短 TTL) + SessionHashDualWriteOld bool `mapstructure:"session_hash_dual_write_old"` + // MetadataBridgeEnabled: RequestMetadata 迁移期是否保留旧 ctxkey.* 兼容桥接 + MetadataBridgeEnabled bool `mapstructure:"metadata_bridge_enabled"` + // StickyResponseIDTTLSeconds: response_id -> account_id 粘连 TTL + StickyResponseIDTTLSeconds int `mapstructure:"sticky_response_id_ttl_seconds"` + // StickyPreviousResponseTTLSeconds: 兼容旧键(当新键未设置时回退) + StickyPreviousResponseTTLSeconds int `mapstructure:"sticky_previous_response_ttl_seconds"` + + SchedulerScoreWeights 
GatewayOpenAIWSSchedulerScoreWeights `mapstructure:"scheduler_score_weights"` +} + +// GatewayOpenAIWSSchedulerScoreWeights 账号调度打分权重。 +type GatewayOpenAIWSSchedulerScoreWeights struct { + Priority float64 `mapstructure:"priority"` + Load float64 `mapstructure:"load"` + Queue float64 `mapstructure:"queue"` + ErrorRate float64 `mapstructure:"error_rate"` + TTFT float64 `mapstructure:"ttft"` +} + +// GatewayUsageRecordConfig 使用量记录异步队列配置 +type GatewayUsageRecordConfig struct { + // WorkerCount: worker 初始数量(自动扩缩容开启时作为初始并发上限) + WorkerCount int `mapstructure:"worker_count"` + // QueueSize: 队列容量(有界) + QueueSize int `mapstructure:"queue_size"` + // TaskTimeoutSeconds: 单个使用量记录任务超时(秒) + TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"` + // OverflowPolicy: 队列满时策略(drop/sample/sync) + OverflowPolicy string `mapstructure:"overflow_policy"` + // OverflowSamplePercent: sample 策略下,同步回写采样百分比(1-100) + OverflowSamplePercent int `mapstructure:"overflow_sample_percent"` + + // AutoScaleEnabled: 是否启用 worker 自动扩缩容 + AutoScaleEnabled bool `mapstructure:"auto_scale_enabled"` + // AutoScaleMinWorkers: 自动扩缩容最小 worker 数 + AutoScaleMinWorkers int `mapstructure:"auto_scale_min_workers"` + // AutoScaleMaxWorkers: 自动扩缩容最大 worker 数 + AutoScaleMaxWorkers int `mapstructure:"auto_scale_max_workers"` + // AutoScaleUpQueuePercent: 队列占用率达到该阈值时触发扩容 + AutoScaleUpQueuePercent int `mapstructure:"auto_scale_up_queue_percent"` + // AutoScaleDownQueuePercent: 队列占用率低于该阈值时触发缩容 + AutoScaleDownQueuePercent int `mapstructure:"auto_scale_down_queue_percent"` + // AutoScaleUpStep: 每次扩容步长 + AutoScaleUpStep int `mapstructure:"auto_scale_up_step"` + // AutoScaleDownStep: 每次缩容步长 + AutoScaleDownStep int `mapstructure:"auto_scale_down_step"` + // AutoScaleCheckIntervalSeconds: 自动扩缩容检测间隔(秒) + AutoScaleCheckIntervalSeconds int `mapstructure:"auto_scale_check_interval_seconds"` + // AutoScaleCooldownSeconds: 自动扩缩容冷却时间(秒) + AutoScaleCooldownSeconds int `mapstructure:"auto_scale_cooldown_seconds"` +} + +// 
SoraModelFiltersConfig Sora 模型过滤配置 +type SoraModelFiltersConfig struct { + // HidePromptEnhance 是否隐藏 prompt-enhance 模型 + HidePromptEnhance bool `mapstructure:"hide_prompt_enhance"` } // TLSFingerprintConfig TLS指纹伪装配置 @@ -479,8 +838,9 @@ type OpsMetricsCollectorCacheConfig struct { type JWTConfig struct { Secret string `mapstructure:"secret"` ExpireHour int `mapstructure:"expire_hour"` - // AccessTokenExpireMinutes: Access Token有效期(分钟),默认15分钟 - // 短有效期减少被盗用风险,配合Refresh Token实现无感续期 + // AccessTokenExpireMinutes: Access Token有效期(分钟) + // - >0: 使用分钟配置(优先级高于 ExpireHour) + // - =0: 回退使用 ExpireHour(向后兼容旧配置) AccessTokenExpireMinutes int `mapstructure:"access_token_expire_minutes"` // RefreshTokenExpireDays: Refresh Token有效期(天),默认30天 RefreshTokenExpireDays int `mapstructure:"refresh_token_expire_days"` @@ -512,7 +872,8 @@ type DefaultConfig struct { } type RateLimitConfig struct { - OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟) + OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟) + OAuth401CooldownMinutes int `mapstructure:"oauth_401_cooldown_minutes"` // OAuth 401临时不可调度冷却(分钟) } // APIKeyAuthCacheConfig API Key 认证缓存配置 @@ -525,6 +886,20 @@ type APIKeyAuthCacheConfig struct { Singleflight bool `mapstructure:"singleflight"` } +// SubscriptionCacheConfig 订阅认证 L1 缓存配置 +type SubscriptionCacheConfig struct { + L1Size int `mapstructure:"l1_size"` + L1TTLSeconds int `mapstructure:"l1_ttl_seconds"` + JitterPercent int `mapstructure:"jitter_percent"` +} + +// SubscriptionMaintenanceConfig 订阅窗口维护后台任务配置。 +// 用于将“请求路径触发的维护动作”有界化,避免高并发下 goroutine 膨胀。 +type SubscriptionMaintenanceConfig struct { + WorkerCount int `mapstructure:"worker_count"` + QueueSize int `mapstructure:"queue_size"` +} + // DashboardCacheConfig 仪表盘统计缓存配置 type DashboardCacheConfig struct { // Enabled: 是否启用仪表盘缓存 @@ -588,7 +963,19 @@ func NormalizeRunMode(value string) string { } } +// Load 读取并校验完整配置(要求 jwt.secret 已显式提供)。 func Load() (*Config, 
error) { + return load(false) +} + +// LoadForBootstrap 读取启动阶段配置。 +// +// 启动阶段允许 jwt.secret 先留空,后续由数据库初始化流程补齐并再次完整校验。 +func LoadForBootstrap() (*Config, error) { + return load(true) +} + +func load(allowMissingJWTSecret bool) (*Config, error) { viper.SetConfigName("config") viper.SetConfigType("yaml") @@ -630,6 +1017,7 @@ func Load() (*Config, error) { if cfg.Server.Mode == "" { cfg.Server.Mode = "debug" } + cfg.Server.FrontendURL = strings.TrimSpace(cfg.Server.FrontendURL) cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret) cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID) cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret) @@ -648,14 +1036,25 @@ func Load() (*Config, error) { cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed) cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove) cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy) + cfg.Log.Level = strings.ToLower(strings.TrimSpace(cfg.Log.Level)) + cfg.Log.Format = strings.ToLower(strings.TrimSpace(cfg.Log.Format)) + cfg.Log.ServiceName = strings.TrimSpace(cfg.Log.ServiceName) + cfg.Log.Environment = strings.TrimSpace(cfg.Log.Environment) + cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel)) + cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath) - if cfg.JWT.Secret == "" { - secret, err := generateJWTSecret(64) - if err != nil { - return nil, fmt.Errorf("generate jwt secret error: %w", err) - } - cfg.JWT.Secret = secret - log.Println("Warning: JWT secret auto-generated. 
Consider setting a fixed secret for production.") + // 兼容旧键 gateway.openai_ws.sticky_previous_response_ttl_seconds。 + // 新键未配置(<=0)时回退旧键;新键优先。 + if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 { + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds + } + + // Normalize UMQ mode: 白名单校验,非法值在加载时一次性 warn 并清空 + if m := cfg.Gateway.UserMessageQueue.Mode; m != "" && m != UMQModeSerialize && m != UMQModeThrottle { + slog.Warn("invalid user_message_queue mode, disabling", + "mode", m, + "valid_modes", []string{UMQModeSerialize, UMQModeThrottle}) + cfg.Gateway.UserMessageQueue.Mode = "" } // Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256) @@ -667,29 +1066,39 @@ func Load() (*Config, error) { } cfg.Totp.EncryptionKey = key cfg.Totp.EncryptionKeyConfigured = false - log.Println("Warning: TOTP encryption key auto-generated. Consider setting a fixed key for production.") + slog.Warn("TOTP encryption key auto-generated. 
Consider setting a fixed key for production.") } else { cfg.Totp.EncryptionKeyConfigured = true } + originalJWTSecret := cfg.JWT.Secret + if allowMissingJWTSecret && originalJWTSecret == "" { + // 启动阶段允许先无 JWT 密钥,后续在数据库初始化后补齐。 + cfg.JWT.Secret = strings.Repeat("0", 32) + } + if err := cfg.Validate(); err != nil { return nil, fmt.Errorf("validate config error: %w", err) } + if allowMissingJWTSecret && originalJWTSecret == "" { + cfg.JWT.Secret = "" + } + if !cfg.Security.URLAllowlist.Enabled { - log.Println("Warning: security.url_allowlist.enabled=false; allowlist/SSRF checks disabled (minimal format validation only).") + slog.Warn("security.url_allowlist.enabled=false; allowlist/SSRF checks disabled (minimal format validation only).") } if !cfg.Security.ResponseHeaders.Enabled { - log.Println("Warning: security.response_headers.enabled=false; configurable header filtering disabled (default allowlist only).") + slog.Warn("security.response_headers.enabled=false; configurable header filtering disabled (default allowlist only).") } if cfg.JWT.Secret != "" && isWeakJWTSecret(cfg.JWT.Secret) { - log.Println("Warning: JWT secret appears weak; use a 32+ character random secret in production.") + slog.Warn("JWT secret appears weak; use a 32+ character random secret in production.") } if len(cfg.Security.ResponseHeaders.AdditionalAllowed) > 0 || len(cfg.Security.ResponseHeaders.ForceRemove) > 0 { - log.Printf("AUDIT: response header policy configured additional_allowed=%v force_remove=%v", - cfg.Security.ResponseHeaders.AdditionalAllowed, - cfg.Security.ResponseHeaders.ForceRemove, + slog.Info("response header policy configured", + "additional_allowed", cfg.Security.ResponseHeaders.AdditionalAllowed, + "force_remove", cfg.Security.ResponseHeaders.ForceRemove, ) } @@ -702,11 +1111,12 @@ func setDefaults() { // Server viper.SetDefault("server.host", "0.0.0.0") viper.SetDefault("server.port", 8080) - viper.SetDefault("server.mode", "debug") + viper.SetDefault("server.mode", 
"release") + viper.SetDefault("server.frontend_url", "") viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头 viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时 viper.SetDefault("server.trusted_proxies", []string{}) - viper.SetDefault("server.max_request_body_size", int64(100*1024*1024)) + viper.SetDefault("server.max_request_body_size", int64(256*1024*1024)) // H2C 默认配置 viper.SetDefault("server.h2c.enabled", false) viper.SetDefault("server.h2c.max_concurrent_streams", uint32(50)) // 50 个并发流 @@ -715,6 +1125,25 @@ func setDefaults() { viper.SetDefault("server.h2c.max_upload_buffer_per_connection", 2<<20) // 2MB viper.SetDefault("server.h2c.max_upload_buffer_per_stream", 512<<10) // 512KB + // Log + viper.SetDefault("log.level", "info") + viper.SetDefault("log.format", "console") + viper.SetDefault("log.service_name", "sub2api") + viper.SetDefault("log.env", "production") + viper.SetDefault("log.caller", true) + viper.SetDefault("log.stacktrace_level", "error") + viper.SetDefault("log.output.to_stdout", true) + viper.SetDefault("log.output.to_file", true) + viper.SetDefault("log.output.file_path", "") + viper.SetDefault("log.rotation.max_size_mb", 100) + viper.SetDefault("log.rotation.max_backups", 10) + viper.SetDefault("log.rotation.max_age_days", 7) + viper.SetDefault("log.rotation.compress", true) + viper.SetDefault("log.rotation.local_time", true) + viper.SetDefault("log.sampling.enabled", false) + viper.SetDefault("log.sampling.initial", 100) + viper.SetDefault("log.sampling.thereafter", 100) + // CORS viper.SetDefault("cors.allowed_origins", []string{}) viper.SetDefault("cors.allow_credentials", true) @@ -737,13 +1166,16 @@ func setDefaults() { viper.SetDefault("security.url_allowlist.crs_hosts", []string{}) viper.SetDefault("security.url_allowlist.allow_private_hosts", true) viper.SetDefault("security.url_allowlist.allow_insecure_http", true) - viper.SetDefault("security.response_headers.enabled", false) + 
viper.SetDefault("security.response_headers.enabled", true) viper.SetDefault("security.response_headers.additional_allowed", []string{}) viper.SetDefault("security.response_headers.force_remove", []string{}) viper.SetDefault("security.csp.enabled", true) viper.SetDefault("security.csp.policy", DefaultCSPPolicy) viper.SetDefault("security.proxy_probe.insecure_skip_verify", false) + // Security - disable direct fallback on proxy error + viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false) + // Billing viper.SetDefault("billing.circuit_breaker.enabled", true) viper.SetDefault("billing.circuit_breaker.failure_threshold", 5) @@ -775,9 +1207,9 @@ func setDefaults() { viper.SetDefault("database.user", "postgres") viper.SetDefault("database.password", "postgres") viper.SetDefault("database.dbname", "sub2api") - viper.SetDefault("database.sslmode", "disable") - viper.SetDefault("database.max_open_conns", 50) - viper.SetDefault("database.max_idle_conns", 10) + viper.SetDefault("database.sslmode", "prefer") + viper.SetDefault("database.max_open_conns", 256) + viper.SetDefault("database.max_idle_conns", 128) viper.SetDefault("database.conn_max_lifetime_minutes", 30) viper.SetDefault("database.conn_max_idle_time_minutes", 5) @@ -789,13 +1221,13 @@ func setDefaults() { viper.SetDefault("redis.dial_timeout_seconds", 5) viper.SetDefault("redis.read_timeout_seconds", 3) viper.SetDefault("redis.write_timeout_seconds", 3) - viper.SetDefault("redis.pool_size", 128) - viper.SetDefault("redis.min_idle_conns", 10) + viper.SetDefault("redis.pool_size", 1024) + viper.SetDefault("redis.min_idle_conns", 128) viper.SetDefault("redis.enable_tls", false) // Ops (vNext) viper.SetDefault("ops.enabled", true) - viper.SetDefault("ops.use_preaggregated_tables", false) + viper.SetDefault("ops.use_preaggregated_tables", true) viper.SetDefault("ops.cleanup.enabled", true) viper.SetDefault("ops.cleanup.schedule", "0 2 * * *") // Retention days: vNext defaults to 30 days across ops 
datasets. @@ -810,9 +1242,9 @@ func setDefaults() { // JWT viper.SetDefault("jwt.secret", "") viper.SetDefault("jwt.expire_hour", 24) - viper.SetDefault("jwt.access_token_expire_minutes", 360) // 6小时Access Token有效期 - viper.SetDefault("jwt.refresh_token_expire_days", 30) // 30天Refresh Token有效期 - viper.SetDefault("jwt.refresh_window_minutes", 2) // 过期前2分钟开始允许刷新 + viper.SetDefault("jwt.access_token_expire_minutes", 0) // 0 表示回退到 expire_hour + viper.SetDefault("jwt.refresh_token_expire_days", 30) // 30天Refresh Token有效期 + viper.SetDefault("jwt.refresh_window_minutes", 2) // 过期前2分钟开始允许刷新 // TOTP viper.SetDefault("totp.encryption_key", "") @@ -829,10 +1261,11 @@ func setDefaults() { // RateLimit viper.SetDefault("rate_limit.overload_cooldown_minutes", 10) + viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 10) - // Pricing - 从 price-mirror 分支同步,该分支维护了 sha256 哈希文件用于增量更新检查 - viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.json") - viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/claude-relay-service/price-mirror/model_prices_and_context_window.sha256") + // Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit,避免分支漂移) + viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json") + viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256") viper.SetDefault("pricing.data_dir", "./data") viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json") viper.SetDefault("pricing.update_interval_hours", 24) @@ -849,6 +1282,11 @@ func setDefaults() { viper.SetDefault("api_key_auth_cache.jitter_percent", 10) viper.SetDefault("api_key_auth_cache.singleflight", 
true) + // Subscription auth L1 cache + viper.SetDefault("subscription_cache.l1_size", 16384) + viper.SetDefault("subscription_cache.l1_ttl_seconds", 10) + viper.SetDefault("subscription_cache.jitter_percent", 10) + // Dashboard cache viper.SetDefault("dashboard_cache.enabled", true) viper.SetDefault("dashboard_cache.key_prefix", "sub2api:") @@ -874,6 +1312,16 @@ func setDefaults() { viper.SetDefault("usage_cleanup.worker_interval_seconds", 10) viper.SetDefault("usage_cleanup.task_timeout_seconds", 1800) + // Idempotency + viper.SetDefault("idempotency.observe_only", true) + viper.SetDefault("idempotency.default_ttl_seconds", 86400) + viper.SetDefault("idempotency.system_operation_ttl_seconds", 3600) + viper.SetDefault("idempotency.processing_timeout_seconds", 30) + viper.SetDefault("idempotency.failed_retry_backoff_seconds", 5) + viper.SetDefault("idempotency.max_stored_response_len", 64*1024) + viper.SetDefault("idempotency.cleanup_interval_seconds", 60) + viper.SetDefault("idempotency.cleanup_batch_size", 500) + // Gateway viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久 viper.SetDefault("gateway.log_upstream_error_body", true) @@ -882,13 +1330,72 @@ func setDefaults() { viper.SetDefault("gateway.failover_on_400", false) viper.SetDefault("gateway.max_account_switches", 10) viper.SetDefault("gateway.max_account_switches_gemini", 3) + viper.SetDefault("gateway.force_codex_cli", false) + viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false) + // OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚) + viper.SetDefault("gateway.openai_ws.enabled", true) + viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false) + viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared") + viper.SetDefault("gateway.openai_ws.oauth_enabled", true) + viper.SetDefault("gateway.openai_ws.apikey_enabled", true) + viper.SetDefault("gateway.openai_ws.force_http", false) + 
viper.SetDefault("gateway.openai_ws.allow_store_recovery", false) + viper.SetDefault("gateway.openai_ws.ingress_previous_response_recovery_enabled", true) + viper.SetDefault("gateway.openai_ws.store_disabled_conn_mode", "strict") + viper.SetDefault("gateway.openai_ws.store_disabled_force_new_conn", true) + viper.SetDefault("gateway.openai_ws.prewarm_generate_enabled", false) + viper.SetDefault("gateway.openai_ws.responses_websockets", false) + viper.SetDefault("gateway.openai_ws.responses_websockets_v2", true) + viper.SetDefault("gateway.openai_ws.max_conns_per_account", 128) + viper.SetDefault("gateway.openai_ws.min_idle_per_account", 4) + viper.SetDefault("gateway.openai_ws.max_idle_per_account", 12) + viper.SetDefault("gateway.openai_ws.dynamic_max_conns_by_account_concurrency_enabled", true) + viper.SetDefault("gateway.openai_ws.oauth_max_conns_factor", 1.0) + viper.SetDefault("gateway.openai_ws.apikey_max_conns_factor", 1.0) + viper.SetDefault("gateway.openai_ws.dial_timeout_seconds", 10) + viper.SetDefault("gateway.openai_ws.read_timeout_seconds", 900) + viper.SetDefault("gateway.openai_ws.write_timeout_seconds", 120) + viper.SetDefault("gateway.openai_ws.pool_target_utilization", 0.7) + viper.SetDefault("gateway.openai_ws.queue_limit_per_conn", 64) + viper.SetDefault("gateway.openai_ws.event_flush_batch_size", 1) + viper.SetDefault("gateway.openai_ws.event_flush_interval_ms", 10) + viper.SetDefault("gateway.openai_ws.prewarm_cooldown_ms", 300) + viper.SetDefault("gateway.openai_ws.fallback_cooldown_seconds", 30) + viper.SetDefault("gateway.openai_ws.retry_backoff_initial_ms", 120) + viper.SetDefault("gateway.openai_ws.retry_backoff_max_ms", 2000) + viper.SetDefault("gateway.openai_ws.retry_jitter_ratio", 0.2) + viper.SetDefault("gateway.openai_ws.retry_total_budget_ms", 5000) + viper.SetDefault("gateway.openai_ws.payload_log_sample_rate", 0.2) + viper.SetDefault("gateway.openai_ws.lb_top_k", 7) + 
viper.SetDefault("gateway.openai_ws.sticky_session_ttl_seconds", 3600) + viper.SetDefault("gateway.openai_ws.session_hash_read_old_fallback", true) + viper.SetDefault("gateway.openai_ws.session_hash_dual_write_old", true) + viper.SetDefault("gateway.openai_ws.metadata_bridge_enabled", true) + viper.SetDefault("gateway.openai_ws.sticky_response_id_ttl_seconds", 3600) + viper.SetDefault("gateway.openai_ws.sticky_previous_response_ttl_seconds", 3600) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.priority", 1.0) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.load", 1.0) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5) viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) - viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) + viper.SetDefault("gateway.antigravity_extra_retries", 10) + viper.SetDefault("gateway.max_body_size", int64(256*1024*1024)) + viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024)) + viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024)) + viper.SetDefault("gateway.gemini_debug_response_headers", false) + viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024)) + viper.SetDefault("gateway.sora_stream_timeout_seconds", 900) + viper.SetDefault("gateway.sora_request_timeout_seconds", 180) + viper.SetDefault("gateway.sora_stream_mode", "force") + viper.SetDefault("gateway.sora_model_filters.hide_prompt_enhance", true) + viper.SetDefault("gateway.sora_media_require_api_key", true) + viper.SetDefault("gateway.sora_media_signed_url_ttl_seconds", 900) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) // HTTP 上游连接池配置(针对 5000+ 并发用户优化) - viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认) + 
viper.SetDefault("gateway.max_idle_conns", 2560) // 最大空闲连接总数(高并发场景可调大) viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认) - viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认) + viper.SetDefault("gateway.max_conns_per_host", 1024) // 每主机最大连接数(含活跃;流式/HTTP1.1 场景可调大,如 2400+) viper.SetDefault("gateway.idle_conn_timeout_seconds", 90) // 空闲连接超时(秒) viper.SetDefault("gateway.max_upstream_clients", 5000) viper.SetDefault("gateway.client_idle_ttl_seconds", 900) @@ -912,16 +1419,73 @@ func setDefaults() { viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3) viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000) viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300) + viper.SetDefault("gateway.usage_record.worker_count", 128) + viper.SetDefault("gateway.usage_record.queue_size", 16384) + viper.SetDefault("gateway.usage_record.task_timeout_seconds", 5) + viper.SetDefault("gateway.usage_record.overflow_policy", UsageRecordOverflowPolicySample) + viper.SetDefault("gateway.usage_record.overflow_sample_percent", 10) + viper.SetDefault("gateway.usage_record.auto_scale_enabled", true) + viper.SetDefault("gateway.usage_record.auto_scale_min_workers", 128) + viper.SetDefault("gateway.usage_record.auto_scale_max_workers", 512) + viper.SetDefault("gateway.usage_record.auto_scale_up_queue_percent", 70) + viper.SetDefault("gateway.usage_record.auto_scale_down_queue_percent", 15) + viper.SetDefault("gateway.usage_record.auto_scale_up_step", 32) + viper.SetDefault("gateway.usage_record.auto_scale_down_step", 16) + viper.SetDefault("gateway.usage_record.auto_scale_check_interval_seconds", 3) + viper.SetDefault("gateway.usage_record.auto_scale_cooldown_seconds", 10) + viper.SetDefault("gateway.user_group_rate_cache_ttl_seconds", 30) + viper.SetDefault("gateway.models_list_cache_ttl_seconds", 15) // TLS指纹伪装配置(默认关闭,需要账号级别单独启用) + // 用户消息串行队列默认值 + 
viper.SetDefault("gateway.user_message_queue.enabled", false) + viper.SetDefault("gateway.user_message_queue.lock_ttl_ms", 120000) + viper.SetDefault("gateway.user_message_queue.wait_timeout_ms", 30000) + viper.SetDefault("gateway.user_message_queue.min_delay_ms", 200) + viper.SetDefault("gateway.user_message_queue.max_delay_ms", 2000) + viper.SetDefault("gateway.user_message_queue.cleanup_interval_seconds", 60) + viper.SetDefault("gateway.tls_fingerprint.enabled", true) viper.SetDefault("concurrency.ping_interval", 10) + // Sora 直连配置 + viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend") + viper.SetDefault("sora.client.timeout_seconds", 120) + viper.SetDefault("sora.client.max_retries", 3) + viper.SetDefault("sora.client.cloudflare_challenge_cooldown_seconds", 900) + viper.SetDefault("sora.client.poll_interval_seconds", 2) + viper.SetDefault("sora.client.max_poll_attempts", 600) + viper.SetDefault("sora.client.recent_task_limit", 50) + viper.SetDefault("sora.client.recent_task_limit_max", 200) + viper.SetDefault("sora.client.debug", false) + viper.SetDefault("sora.client.use_openai_token_provider", false) + viper.SetDefault("sora.client.headers", map[string]string{}) + viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + viper.SetDefault("sora.client.disable_tls_fingerprint", false) + viper.SetDefault("sora.client.curl_cffi_sidecar.enabled", true) + viper.SetDefault("sora.client.curl_cffi_sidecar.base_url", "http://sora-curl-cffi-sidecar:8080") + viper.SetDefault("sora.client.curl_cffi_sidecar.impersonate", "chrome131") + viper.SetDefault("sora.client.curl_cffi_sidecar.timeout_seconds", 60) + viper.SetDefault("sora.client.curl_cffi_sidecar.session_reuse_enabled", true) + viper.SetDefault("sora.client.curl_cffi_sidecar.session_ttl_seconds", 3600) + + viper.SetDefault("sora.storage.type", "local") + viper.SetDefault("sora.storage.local_path", "") + 
viper.SetDefault("sora.storage.fallback_to_upstream", true) + viper.SetDefault("sora.storage.max_concurrent_downloads", 4) + viper.SetDefault("sora.storage.download_timeout_seconds", 120) + viper.SetDefault("sora.storage.max_download_bytes", int64(200<<20)) + viper.SetDefault("sora.storage.debug", false) + viper.SetDefault("sora.storage.cleanup.enabled", true) + viper.SetDefault("sora.storage.cleanup.retention_days", 7) + viper.SetDefault("sora.storage.cleanup.schedule", "0 3 * * *") + // TokenRefresh viper.SetDefault("token_refresh.enabled", true) viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次 viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token) viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次 viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒 + viper.SetDefault("token_refresh.sync_linked_sora_accounts", false) // 默认不跨平台覆盖 Sora token // Gemini OAuth - configure via environment variables or config file // GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET @@ -930,9 +1494,103 @@ func setDefaults() { viper.SetDefault("gemini.oauth.client_secret", "") viper.SetDefault("gemini.oauth.scopes", "") viper.SetDefault("gemini.quota.policy", "") + + // Subscription Maintenance (bounded queue + worker pool) + viper.SetDefault("subscription_maintenance.worker_count", 2) + viper.SetDefault("subscription_maintenance.queue_size", 1024) + } func (c *Config) Validate() error { + jwtSecret := strings.TrimSpace(c.JWT.Secret) + if jwtSecret == "" { + return fmt.Errorf("jwt.secret is required") + } + // NOTE: 按 UTF-8 编码后的字节长度计算。 + // 选择 bytes 而不是 rune 计数,确保二进制/随机串的长度语义更接近“熵”而非“字符数”。 + if len([]byte(jwtSecret)) < 32 { + return fmt.Errorf("jwt.secret must be at least 32 bytes") + } + switch c.Log.Level { + case "debug", "info", "warn", "error": + case "": + return fmt.Errorf("log.level is required") + default: + return fmt.Errorf("log.level must be one of: debug/info/warn/error") + } + 
switch c.Log.Format { + case "json", "console": + case "": + return fmt.Errorf("log.format is required") + default: + return fmt.Errorf("log.format must be one of: json/console") + } + switch c.Log.StacktraceLevel { + case "none", "error", "fatal": + case "": + return fmt.Errorf("log.stacktrace_level is required") + default: + return fmt.Errorf("log.stacktrace_level must be one of: none/error/fatal") + } + if !c.Log.Output.ToStdout && !c.Log.Output.ToFile { + return fmt.Errorf("log.output.to_stdout and log.output.to_file cannot both be false") + } + if c.Log.Rotation.MaxSizeMB <= 0 { + return fmt.Errorf("log.rotation.max_size_mb must be positive") + } + if c.Log.Rotation.MaxBackups < 0 { + return fmt.Errorf("log.rotation.max_backups must be non-negative") + } + if c.Log.Rotation.MaxAgeDays < 0 { + return fmt.Errorf("log.rotation.max_age_days must be non-negative") + } + if c.Log.Sampling.Enabled { + if c.Log.Sampling.Initial <= 0 { + return fmt.Errorf("log.sampling.initial must be positive when sampling is enabled") + } + if c.Log.Sampling.Thereafter <= 0 { + return fmt.Errorf("log.sampling.thereafter must be positive when sampling is enabled") + } + } else { + if c.Log.Sampling.Initial < 0 { + return fmt.Errorf("log.sampling.initial must be non-negative") + } + if c.Log.Sampling.Thereafter < 0 { + return fmt.Errorf("log.sampling.thereafter must be non-negative") + } + } + + if c.SubscriptionMaintenance.WorkerCount < 0 { + return fmt.Errorf("subscription_maintenance.worker_count must be non-negative") + } + if c.SubscriptionMaintenance.QueueSize < 0 { + return fmt.Errorf("subscription_maintenance.queue_size must be non-negative") + } + + // Gemini OAuth 配置校验:client_id 与 client_secret 必须同时设置或同时留空。 + // 留空时表示使用内置的 Gemini CLI OAuth 客户端(其 client_secret 通过环境变量注入)。 + geminiClientID := strings.TrimSpace(c.Gemini.OAuth.ClientID) + geminiClientSecret := strings.TrimSpace(c.Gemini.OAuth.ClientSecret) + if (geminiClientID == "") != (geminiClientSecret == "") { + return 
fmt.Errorf("gemini.oauth.client_id and gemini.oauth.client_secret must be both set or both empty") + } + + if strings.TrimSpace(c.Server.FrontendURL) != "" { + if err := ValidateAbsoluteHTTPURL(c.Server.FrontendURL); err != nil { + return fmt.Errorf("server.frontend_url invalid: %w", err) + } + u, err := url.Parse(strings.TrimSpace(c.Server.FrontendURL)) + if err != nil { + return fmt.Errorf("server.frontend_url invalid: %w", err) + } + if u.RawQuery != "" || u.ForceQuery { + return fmt.Errorf("server.frontend_url invalid: must not include query") + } + if u.User != nil { + return fmt.Errorf("server.frontend_url invalid: must not include userinfo") + } + warnIfInsecureURL("server.frontend_url", c.Server.FrontendURL) + } if c.JWT.ExpireHour <= 0 { return fmt.Errorf("jwt.expire_hour must be positive") } @@ -940,20 +1598,20 @@ func (c *Config) Validate() error { return fmt.Errorf("jwt.expire_hour must be <= 168 (7 days)") } if c.JWT.ExpireHour > 24 { - log.Printf("Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security.", c.JWT.ExpireHour) + slog.Warn("jwt.expire_hour is high; consider shorter expiration for security", "expire_hour", c.JWT.ExpireHour) } // JWT Refresh Token配置验证 - if c.JWT.AccessTokenExpireMinutes <= 0 { - return fmt.Errorf("jwt.access_token_expire_minutes must be positive") + if c.JWT.AccessTokenExpireMinutes < 0 { + return fmt.Errorf("jwt.access_token_expire_minutes must be non-negative") } if c.JWT.AccessTokenExpireMinutes > 720 { - log.Printf("Warning: jwt.access_token_expire_minutes is %d (> 720). 
Consider shorter expiration for security.", c.JWT.AccessTokenExpireMinutes) + slog.Warn("jwt.access_token_expire_minutes is high; consider shorter expiration for security", "access_token_expire_minutes", c.JWT.AccessTokenExpireMinutes) } if c.JWT.RefreshTokenExpireDays <= 0 { return fmt.Errorf("jwt.refresh_token_expire_days must be positive") } if c.JWT.RefreshTokenExpireDays > 90 { - log.Printf("Warning: jwt.refresh_token_expire_days is %d (> 90). Consider shorter expiration for security.", c.JWT.RefreshTokenExpireDays) + slog.Warn("jwt.refresh_token_expire_days is high; consider shorter expiration for security", "refresh_token_expire_days", c.JWT.RefreshTokenExpireDays) } if c.JWT.RefreshWindowMinutes < 0 { return fmt.Errorf("jwt.refresh_window_minutes must be non-negative") @@ -1159,9 +1817,116 @@ func (c *Config) Validate() error { return fmt.Errorf("usage_cleanup.task_timeout_seconds must be non-negative") } } + if c.Idempotency.DefaultTTLSeconds <= 0 { + return fmt.Errorf("idempotency.default_ttl_seconds must be positive") + } + if c.Idempotency.SystemOperationTTLSeconds <= 0 { + return fmt.Errorf("idempotency.system_operation_ttl_seconds must be positive") + } + if c.Idempotency.ProcessingTimeoutSeconds <= 0 { + return fmt.Errorf("idempotency.processing_timeout_seconds must be positive") + } + if c.Idempotency.FailedRetryBackoffSeconds <= 0 { + return fmt.Errorf("idempotency.failed_retry_backoff_seconds must be positive") + } + if c.Idempotency.MaxStoredResponseLen <= 0 { + return fmt.Errorf("idempotency.max_stored_response_len must be positive") + } + if c.Idempotency.CleanupIntervalSeconds <= 0 { + return fmt.Errorf("idempotency.cleanup_interval_seconds must be positive") + } + if c.Idempotency.CleanupBatchSize <= 0 { + return fmt.Errorf("idempotency.cleanup_batch_size must be positive") + } if c.Gateway.MaxBodySize <= 0 { return fmt.Errorf("gateway.max_body_size must be positive") } + if c.Gateway.UpstreamResponseReadMaxBytes <= 0 { + return 
fmt.Errorf("gateway.upstream_response_read_max_bytes must be positive") + } + if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 { + return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive") + } + if c.Gateway.SoraMaxBodySize < 0 { + return fmt.Errorf("gateway.sora_max_body_size must be non-negative") + } + if c.Gateway.SoraStreamTimeoutSeconds < 0 { + return fmt.Errorf("gateway.sora_stream_timeout_seconds must be non-negative") + } + if c.Gateway.SoraRequestTimeoutSeconds < 0 { + return fmt.Errorf("gateway.sora_request_timeout_seconds must be non-negative") + } + if c.Gateway.SoraMediaSignedURLTTLSeconds < 0 { + return fmt.Errorf("gateway.sora_media_signed_url_ttl_seconds must be non-negative") + } + if mode := strings.TrimSpace(strings.ToLower(c.Gateway.SoraStreamMode)); mode != "" { + switch mode { + case "force", "error": + default: + return fmt.Errorf("gateway.sora_stream_mode must be one of: force/error") + } + } + if c.Sora.Client.TimeoutSeconds < 0 { + return fmt.Errorf("sora.client.timeout_seconds must be non-negative") + } + if c.Sora.Client.MaxRetries < 0 { + return fmt.Errorf("sora.client.max_retries must be non-negative") + } + if c.Sora.Client.CloudflareChallengeCooldownSeconds < 0 { + return fmt.Errorf("sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") + } + if c.Sora.Client.PollIntervalSeconds < 0 { + return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative") + } + if c.Sora.Client.MaxPollAttempts < 0 { + return fmt.Errorf("sora.client.max_poll_attempts must be non-negative") + } + if c.Sora.Client.RecentTaskLimit < 0 { + return fmt.Errorf("sora.client.recent_task_limit must be non-negative") + } + if c.Sora.Client.RecentTaskLimitMax < 0 { + return fmt.Errorf("sora.client.recent_task_limit_max must be non-negative") + } + if c.Sora.Client.RecentTaskLimitMax > 0 && c.Sora.Client.RecentTaskLimit > 0 && + c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit { + 
c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit + } + if c.Sora.Client.CurlCFFISidecar.TimeoutSeconds < 0 { + return fmt.Errorf("sora.client.curl_cffi_sidecar.timeout_seconds must be non-negative") + } + if c.Sora.Client.CurlCFFISidecar.SessionTTLSeconds < 0 { + return fmt.Errorf("sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") + } + if !c.Sora.Client.CurlCFFISidecar.Enabled { + return fmt.Errorf("sora.client.curl_cffi_sidecar.enabled must be true") + } + if strings.TrimSpace(c.Sora.Client.CurlCFFISidecar.BaseURL) == "" { + return fmt.Errorf("sora.client.curl_cffi_sidecar.base_url is required") + } + if c.Sora.Storage.MaxConcurrentDownloads < 0 { + return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative") + } + if c.Sora.Storage.DownloadTimeoutSeconds < 0 { + return fmt.Errorf("sora.storage.download_timeout_seconds must be non-negative") + } + if c.Sora.Storage.MaxDownloadBytes < 0 { + return fmt.Errorf("sora.storage.max_download_bytes must be non-negative") + } + if c.Sora.Storage.Cleanup.Enabled { + if c.Sora.Storage.Cleanup.RetentionDays <= 0 { + return fmt.Errorf("sora.storage.cleanup.retention_days must be positive") + } + if strings.TrimSpace(c.Sora.Storage.Cleanup.Schedule) == "" { + return fmt.Errorf("sora.storage.cleanup.schedule is required when cleanup is enabled") + } + } else { + if c.Sora.Storage.Cleanup.RetentionDays < 0 { + return fmt.Errorf("sora.storage.cleanup.retention_days must be non-negative") + } + } + if storageType := strings.TrimSpace(strings.ToLower(c.Sora.Storage.Type)); storageType != "" && storageType != "local" { + return fmt.Errorf("sora.storage.type must be 'local'") + } if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" { switch c.Gateway.ConnectionPoolIsolation { case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy: @@ -1183,7 +1948,7 @@ func (c *Config) Validate() error { return 
fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive") } if c.Gateway.IdleConnTimeoutSeconds > 180 { - log.Printf("Warning: gateway.idle_conn_timeout_seconds is %d (> 180). Consider 60-120 seconds for better connection reuse.", c.Gateway.IdleConnTimeoutSeconds) + slog.Warn("gateway.idle_conn_timeout_seconds is high; consider 60-120 seconds for better connection reuse", "idle_conn_timeout_seconds", c.Gateway.IdleConnTimeoutSeconds) } if c.Gateway.MaxUpstreamClients <= 0 { return fmt.Errorf("gateway.max_upstream_clients must be positive") @@ -1208,12 +1973,188 @@ func (c *Config) Validate() error { (c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) { return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds") } + // 兼容旧键 sticky_previous_response_ttl_seconds + if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 { + c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds + } + if c.Gateway.OpenAIWS.MaxConnsPerAccount <= 0 { + return fmt.Errorf("gateway.openai_ws.max_conns_per_account must be positive") + } + if c.Gateway.OpenAIWS.MinIdlePerAccount < 0 { + return fmt.Errorf("gateway.openai_ws.min_idle_per_account must be non-negative") + } + if c.Gateway.OpenAIWS.MaxIdlePerAccount < 0 { + return fmt.Errorf("gateway.openai_ws.max_idle_per_account must be non-negative") + } + if c.Gateway.OpenAIWS.MinIdlePerAccount > c.Gateway.OpenAIWS.MaxIdlePerAccount { + return fmt.Errorf("gateway.openai_ws.min_idle_per_account must be <= max_idle_per_account") + } + if c.Gateway.OpenAIWS.MaxIdlePerAccount > c.Gateway.OpenAIWS.MaxConnsPerAccount { + return fmt.Errorf("gateway.openai_ws.max_idle_per_account must be <= max_conns_per_account") + } + if c.Gateway.OpenAIWS.OAuthMaxConnsFactor <= 0 { + return fmt.Errorf("gateway.openai_ws.oauth_max_conns_factor must be positive") + } + if 
c.Gateway.OpenAIWS.APIKeyMaxConnsFactor <= 0 { + return fmt.Errorf("gateway.openai_ws.apikey_max_conns_factor must be positive") + } + if c.Gateway.OpenAIWS.DialTimeoutSeconds <= 0 { + return fmt.Errorf("gateway.openai_ws.dial_timeout_seconds must be positive") + } + if c.Gateway.OpenAIWS.ReadTimeoutSeconds <= 0 { + return fmt.Errorf("gateway.openai_ws.read_timeout_seconds must be positive") + } + if c.Gateway.OpenAIWS.WriteTimeoutSeconds <= 0 { + return fmt.Errorf("gateway.openai_ws.write_timeout_seconds must be positive") + } + if c.Gateway.OpenAIWS.PoolTargetUtilization <= 0 || c.Gateway.OpenAIWS.PoolTargetUtilization > 1 { + return fmt.Errorf("gateway.openai_ws.pool_target_utilization must be within (0,1]") + } + if c.Gateway.OpenAIWS.QueueLimitPerConn <= 0 { + return fmt.Errorf("gateway.openai_ws.queue_limit_per_conn must be positive") + } + if c.Gateway.OpenAIWS.EventFlushBatchSize <= 0 { + return fmt.Errorf("gateway.openai_ws.event_flush_batch_size must be positive") + } + if c.Gateway.OpenAIWS.EventFlushIntervalMS < 0 { + return fmt.Errorf("gateway.openai_ws.event_flush_interval_ms must be non-negative") + } + if c.Gateway.OpenAIWS.PrewarmCooldownMS < 0 { + return fmt.Errorf("gateway.openai_ws.prewarm_cooldown_ms must be non-negative") + } + if c.Gateway.OpenAIWS.FallbackCooldownSeconds < 0 { + return fmt.Errorf("gateway.openai_ws.fallback_cooldown_seconds must be non-negative") + } + if c.Gateway.OpenAIWS.RetryBackoffInitialMS < 0 { + return fmt.Errorf("gateway.openai_ws.retry_backoff_initial_ms must be non-negative") + } + if c.Gateway.OpenAIWS.RetryBackoffMaxMS < 0 { + return fmt.Errorf("gateway.openai_ws.retry_backoff_max_ms must be non-negative") + } + if c.Gateway.OpenAIWS.RetryBackoffInitialMS > 0 && c.Gateway.OpenAIWS.RetryBackoffMaxMS > 0 && + c.Gateway.OpenAIWS.RetryBackoffMaxMS < c.Gateway.OpenAIWS.RetryBackoffInitialMS { + return fmt.Errorf("gateway.openai_ws.retry_backoff_max_ms must be >= retry_backoff_initial_ms") + } + if 
c.Gateway.OpenAIWS.RetryJitterRatio < 0 || c.Gateway.OpenAIWS.RetryJitterRatio > 1 { + return fmt.Errorf("gateway.openai_ws.retry_jitter_ratio must be within [0,1]") + } + if c.Gateway.OpenAIWS.RetryTotalBudgetMS < 0 { + return fmt.Errorf("gateway.openai_ws.retry_total_budget_ms must be non-negative") + } + if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" { + switch mode { + case "off", "shared", "dedicated": + default: + return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated") + } + } + if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" { + switch mode { + case "strict", "adaptive", "off": + default: + return fmt.Errorf("gateway.openai_ws.store_disabled_conn_mode must be one of strict|adaptive|off") + } + } + if c.Gateway.OpenAIWS.PayloadLogSampleRate < 0 || c.Gateway.OpenAIWS.PayloadLogSampleRate > 1 { + return fmt.Errorf("gateway.openai_ws.payload_log_sample_rate must be within [0,1]") + } + if c.Gateway.OpenAIWS.LBTopK <= 0 { + return fmt.Errorf("gateway.openai_ws.lb_top_k must be positive") + } + if c.Gateway.OpenAIWS.StickySessionTTLSeconds <= 0 { + return fmt.Errorf("gateway.openai_ws.sticky_session_ttl_seconds must be positive") + } + if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 { + return fmt.Errorf("gateway.openai_ws.sticky_response_id_ttl_seconds must be positive") + } + if c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds < 0 { + return fmt.Errorf("gateway.openai_ws.sticky_previous_response_ttl_seconds must be non-negative") + } + if c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority < 0 || + c.Gateway.OpenAIWS.SchedulerScoreWeights.Load < 0 || + c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue < 0 || + c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate < 0 || + c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT < 0 { + return fmt.Errorf("gateway.openai_ws.scheduler_score_weights.* must be non-negative") + } + 
weightSum := c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority + + c.Gateway.OpenAIWS.SchedulerScoreWeights.Load + + c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue + + c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate + + c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT + if weightSum <= 0 { + return fmt.Errorf("gateway.openai_ws.scheduler_score_weights must not all be zero") + } if c.Gateway.MaxLineSize < 0 { return fmt.Errorf("gateway.max_line_size must be non-negative") } if c.Gateway.MaxLineSize != 0 && c.Gateway.MaxLineSize < 1024*1024 { return fmt.Errorf("gateway.max_line_size must be at least 1MB") } + if c.Gateway.UsageRecord.WorkerCount <= 0 { + return fmt.Errorf("gateway.usage_record.worker_count must be positive") + } + if c.Gateway.UsageRecord.QueueSize <= 0 { + return fmt.Errorf("gateway.usage_record.queue_size must be positive") + } + if c.Gateway.UsageRecord.TaskTimeoutSeconds <= 0 { + return fmt.Errorf("gateway.usage_record.task_timeout_seconds must be positive") + } + switch strings.ToLower(strings.TrimSpace(c.Gateway.UsageRecord.OverflowPolicy)) { + case UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync: + default: + return fmt.Errorf("gateway.usage_record.overflow_policy must be one of: %s/%s/%s", + UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync) + } + if c.Gateway.UsageRecord.OverflowSamplePercent < 0 || c.Gateway.UsageRecord.OverflowSamplePercent > 100 { + return fmt.Errorf("gateway.usage_record.overflow_sample_percent must be between 0-100") + } + if strings.EqualFold(strings.TrimSpace(c.Gateway.UsageRecord.OverflowPolicy), UsageRecordOverflowPolicySample) && + c.Gateway.UsageRecord.OverflowSamplePercent <= 0 { + return fmt.Errorf("gateway.usage_record.overflow_sample_percent must be positive when overflow_policy=sample") + } + if c.Gateway.UsageRecord.AutoScaleEnabled { + if c.Gateway.UsageRecord.AutoScaleMinWorkers <= 0 { + return 
fmt.Errorf("gateway.usage_record.auto_scale_min_workers must be positive") + } + if c.Gateway.UsageRecord.AutoScaleMaxWorkers <= 0 { + return fmt.Errorf("gateway.usage_record.auto_scale_max_workers must be positive") + } + if c.Gateway.UsageRecord.AutoScaleMaxWorkers < c.Gateway.UsageRecord.AutoScaleMinWorkers { + return fmt.Errorf("gateway.usage_record.auto_scale_max_workers must be >= auto_scale_min_workers") + } + if c.Gateway.UsageRecord.WorkerCount < c.Gateway.UsageRecord.AutoScaleMinWorkers || + c.Gateway.UsageRecord.WorkerCount > c.Gateway.UsageRecord.AutoScaleMaxWorkers { + return fmt.Errorf("gateway.usage_record.worker_count must be between auto_scale_min_workers and auto_scale_max_workers") + } + if c.Gateway.UsageRecord.AutoScaleUpQueuePercent <= 0 || c.Gateway.UsageRecord.AutoScaleUpQueuePercent > 100 { + return fmt.Errorf("gateway.usage_record.auto_scale_up_queue_percent must be between 1-100") + } + if c.Gateway.UsageRecord.AutoScaleDownQueuePercent < 0 || c.Gateway.UsageRecord.AutoScaleDownQueuePercent >= 100 { + return fmt.Errorf("gateway.usage_record.auto_scale_down_queue_percent must be between 0-99") + } + if c.Gateway.UsageRecord.AutoScaleDownQueuePercent >= c.Gateway.UsageRecord.AutoScaleUpQueuePercent { + return fmt.Errorf("gateway.usage_record.auto_scale_down_queue_percent must be less than auto_scale_up_queue_percent") + } + if c.Gateway.UsageRecord.AutoScaleUpStep <= 0 { + return fmt.Errorf("gateway.usage_record.auto_scale_up_step must be positive") + } + if c.Gateway.UsageRecord.AutoScaleDownStep <= 0 { + return fmt.Errorf("gateway.usage_record.auto_scale_down_step must be positive") + } + if c.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds <= 0 { + return fmt.Errorf("gateway.usage_record.auto_scale_check_interval_seconds must be positive") + } + if c.Gateway.UsageRecord.AutoScaleCooldownSeconds < 0 { + return fmt.Errorf("gateway.usage_record.auto_scale_cooldown_seconds must be non-negative") + } + } + if 
c.Gateway.UserGroupRateCacheTTLSeconds <= 0 { + return fmt.Errorf("gateway.user_group_rate_cache_ttl_seconds must be positive") + } + if c.Gateway.ModelsListCacheTTLSeconds < 10 || c.Gateway.ModelsListCacheTTLSeconds > 30 { + return fmt.Errorf("gateway.models_list_cache_ttl_seconds must be between 10-30") + } if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 { return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive") } @@ -1420,6 +2361,6 @@ func warnIfInsecureURL(field, raw string) { return } if strings.EqualFold(u.Scheme, "http") { - log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field) + slog.Warn("url uses http scheme; use https in production to avoid token leakage", "field", field) } } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index f734619f..e3b592e2 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -6,8 +6,28 @@ import ( "time" "github.com/spf13/viper" + "github.com/stretchr/testify/require" ) +func resetViperWithJWTSecret(t *testing.T) { + t.Helper() + viper.Reset() + t.Setenv("JWT_SECRET", strings.Repeat("x", 32)) +} + +func TestLoadForBootstrapAllowsMissingJWTSecret(t *testing.T) { + viper.Reset() + t.Setenv("JWT_SECRET", "") + + cfg, err := LoadForBootstrap() + if err != nil { + t.Fatalf("LoadForBootstrap() error: %v", err) + } + if cfg.JWT.Secret != "" { + t.Fatalf("LoadForBootstrap() should keep empty jwt.secret during bootstrap") + } +} + func TestNormalizeRunMode(t *testing.T) { tests := []struct { input string @@ -29,7 +49,7 @@ func TestNormalizeRunMode(t *testing.T) { } func TestLoadDefaultSchedulingConfig(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -56,8 +76,141 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) { } } +func TestLoadDefaultOpenAIWSConfig(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() 
+ if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if !cfg.Gateway.OpenAIWS.Enabled { + t.Fatalf("Gateway.OpenAIWS.Enabled = false, want true") + } + if !cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 { + t.Fatalf("Gateway.OpenAIWS.ResponsesWebsocketsV2 = false, want true") + } + if cfg.Gateway.OpenAIWS.ResponsesWebsockets { + t.Fatalf("Gateway.OpenAIWS.ResponsesWebsockets = true, want false") + } + if !cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled { + t.Fatalf("Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = false, want true") + } + if cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor != 1.0 { + t.Fatalf("Gateway.OpenAIWS.OAuthMaxConnsFactor = %v, want 1.0", cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor) + } + if cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor != 1.0 { + t.Fatalf("Gateway.OpenAIWS.APIKeyMaxConnsFactor = %v, want 1.0", cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor) + } + if cfg.Gateway.OpenAIWS.StickySessionTTLSeconds != 3600 { + t.Fatalf("Gateway.OpenAIWS.StickySessionTTLSeconds = %d, want 3600", cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) + } + if !cfg.Gateway.OpenAIWS.SessionHashReadOldFallback { + t.Fatalf("Gateway.OpenAIWS.SessionHashReadOldFallback = false, want true") + } + if !cfg.Gateway.OpenAIWS.SessionHashDualWriteOld { + t.Fatalf("Gateway.OpenAIWS.SessionHashDualWriteOld = false, want true") + } + if !cfg.Gateway.OpenAIWS.MetadataBridgeEnabled { + t.Fatalf("Gateway.OpenAIWS.MetadataBridgeEnabled = false, want true") + } + if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds != 3600 { + t.Fatalf("Gateway.OpenAIWS.StickyResponseIDTTLSeconds = %d, want 3600", cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds) + } + if cfg.Gateway.OpenAIWS.FallbackCooldownSeconds != 30 { + t.Fatalf("Gateway.OpenAIWS.FallbackCooldownSeconds = %d, want 30", cfg.Gateway.OpenAIWS.FallbackCooldownSeconds) + } + if cfg.Gateway.OpenAIWS.EventFlushBatchSize != 1 { + t.Fatalf("Gateway.OpenAIWS.EventFlushBatchSize = %d, want 1", 
cfg.Gateway.OpenAIWS.EventFlushBatchSize) + } + if cfg.Gateway.OpenAIWS.EventFlushIntervalMS != 10 { + t.Fatalf("Gateway.OpenAIWS.EventFlushIntervalMS = %d, want 10", cfg.Gateway.OpenAIWS.EventFlushIntervalMS) + } + if cfg.Gateway.OpenAIWS.PrewarmCooldownMS != 300 { + t.Fatalf("Gateway.OpenAIWS.PrewarmCooldownMS = %d, want 300", cfg.Gateway.OpenAIWS.PrewarmCooldownMS) + } + if cfg.Gateway.OpenAIWS.RetryBackoffInitialMS != 120 { + t.Fatalf("Gateway.OpenAIWS.RetryBackoffInitialMS = %d, want 120", cfg.Gateway.OpenAIWS.RetryBackoffInitialMS) + } + if cfg.Gateway.OpenAIWS.RetryBackoffMaxMS != 2000 { + t.Fatalf("Gateway.OpenAIWS.RetryBackoffMaxMS = %d, want 2000", cfg.Gateway.OpenAIWS.RetryBackoffMaxMS) + } + if cfg.Gateway.OpenAIWS.RetryJitterRatio != 0.2 { + t.Fatalf("Gateway.OpenAIWS.RetryJitterRatio = %v, want 0.2", cfg.Gateway.OpenAIWS.RetryJitterRatio) + } + if cfg.Gateway.OpenAIWS.RetryTotalBudgetMS != 5000 { + t.Fatalf("Gateway.OpenAIWS.RetryTotalBudgetMS = %d, want 5000", cfg.Gateway.OpenAIWS.RetryTotalBudgetMS) + } + if cfg.Gateway.OpenAIWS.PayloadLogSampleRate != 0.2 { + t.Fatalf("Gateway.OpenAIWS.PayloadLogSampleRate = %v, want 0.2", cfg.Gateway.OpenAIWS.PayloadLogSampleRate) + } + if !cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn { + t.Fatalf("Gateway.OpenAIWS.StoreDisabledForceNewConn = false, want true") + } + if cfg.Gateway.OpenAIWS.StoreDisabledConnMode != "strict" { + t.Fatalf("Gateway.OpenAIWS.StoreDisabledConnMode = %q, want %q", cfg.Gateway.OpenAIWS.StoreDisabledConnMode, "strict") + } + if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled { + t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false") + } + if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" { + t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared") + } +} + +func TestLoadOpenAIWSStickyTTLCompatibility(t *testing.T) { + resetViperWithJWTSecret(t) + t.Setenv("GATEWAY_OPENAI_WS_STICKY_RESPONSE_ID_TTL_SECONDS", "0") + 
t.Setenv("GATEWAY_OPENAI_WS_STICKY_PREVIOUS_RESPONSE_TTL_SECONDS", "7200") + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds != 7200 { + t.Fatalf("StickyResponseIDTTLSeconds = %d, want 7200", cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds) + } +} + +func TestLoadDefaultIdempotencyConfig(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if !cfg.Idempotency.ObserveOnly { + t.Fatalf("Idempotency.ObserveOnly = false, want true") + } + if cfg.Idempotency.DefaultTTLSeconds != 86400 { + t.Fatalf("Idempotency.DefaultTTLSeconds = %d, want 86400", cfg.Idempotency.DefaultTTLSeconds) + } + if cfg.Idempotency.SystemOperationTTLSeconds != 3600 { + t.Fatalf("Idempotency.SystemOperationTTLSeconds = %d, want 3600", cfg.Idempotency.SystemOperationTTLSeconds) + } +} + +func TestLoadIdempotencyConfigFromEnv(t *testing.T) { + resetViperWithJWTSecret(t) + t.Setenv("IDEMPOTENCY_OBSERVE_ONLY", "false") + t.Setenv("IDEMPOTENCY_DEFAULT_TTL_SECONDS", "600") + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + if cfg.Idempotency.ObserveOnly { + t.Fatalf("Idempotency.ObserveOnly = true, want false") + } + if cfg.Idempotency.DefaultTTLSeconds != 600 { + t.Fatalf("Idempotency.DefaultTTLSeconds = %d, want 600", cfg.Idempotency.DefaultTTLSeconds) + } +} + func TestLoadSchedulingConfigFromEnv(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5") cfg, err := Load() @@ -71,7 +224,7 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) { } func TestLoadDefaultSecurityToggles(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -87,13 +240,69 @@ func TestLoadDefaultSecurityToggles(t *testing.T) { if !cfg.Security.URLAllowlist.AllowPrivateHosts { t.Fatalf("URLAllowlist.AllowPrivateHosts = false, 
want true") } - if cfg.Security.ResponseHeaders.Enabled { - t.Fatalf("ResponseHeaders.Enabled = true, want false") + if !cfg.Security.ResponseHeaders.Enabled { + t.Fatalf("ResponseHeaders.Enabled = false, want true") + } +} + +func TestLoadDefaultServerMode(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.Server.Mode != "release" { + t.Fatalf("Server.Mode = %q, want %q", cfg.Server.Mode, "release") + } +} + +func TestLoadDefaultJWTAccessTokenExpireMinutes(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.JWT.ExpireHour != 24 { + t.Fatalf("JWT.ExpireHour = %d, want 24", cfg.JWT.ExpireHour) + } + if cfg.JWT.AccessTokenExpireMinutes != 0 { + t.Fatalf("JWT.AccessTokenExpireMinutes = %d, want 0", cfg.JWT.AccessTokenExpireMinutes) + } +} + +func TestLoadJWTAccessTokenExpireMinutesFromEnv(t *testing.T) { + resetViperWithJWTSecret(t) + t.Setenv("JWT_ACCESS_TOKEN_EXPIRE_MINUTES", "90") + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.JWT.AccessTokenExpireMinutes != 90 { + t.Fatalf("JWT.AccessTokenExpireMinutes = %d, want 90", cfg.JWT.AccessTokenExpireMinutes) + } +} + +func TestLoadDefaultDatabaseSSLMode(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.Database.SSLMode != "prefer" { + t.Fatalf("Database.SSLMode = %q, want %q", cfg.Database.SSLMode, "prefer") } } func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -118,7 +327,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) { } func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -143,7 +352,7 @@ func 
TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) { } func TestLoadDefaultDashboardCacheConfig(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -168,7 +377,7 @@ func TestLoadDefaultDashboardCacheConfig(t *testing.T) { } func TestValidateDashboardCacheConfigEnabled(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -188,7 +397,7 @@ func TestValidateDashboardCacheConfigEnabled(t *testing.T) { } func TestValidateDashboardCacheConfigDisabled(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -207,7 +416,7 @@ func TestValidateDashboardCacheConfigDisabled(t *testing.T) { } func TestLoadDefaultDashboardAggregationConfig(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -244,7 +453,7 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) { } func TestValidateDashboardAggregationConfigDisabled(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -263,7 +472,7 @@ func TestValidateDashboardAggregationConfigDisabled(t *testing.T) { } func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -282,7 +491,7 @@ func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) { } func TestLoadDefaultUsageCleanupConfig(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -307,7 +516,7 @@ func TestLoadDefaultUsageCleanupConfig(t *testing.T) { } func TestValidateUsageCleanupConfigEnabled(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -326,7 +535,7 @@ func TestValidateUsageCleanupConfigEnabled(t *testing.T) { } func TestValidateUsageCleanupConfigDisabled(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err 
!= nil { @@ -424,6 +633,40 @@ func TestValidateAbsoluteHTTPURL(t *testing.T) { } } +func TestValidateServerFrontendURL(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Server.FrontendURL = "https://example.com" + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate() frontend_url valid error: %v", err) + } + + cfg.Server.FrontendURL = "https://example.com/path" + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate() frontend_url with path valid error: %v", err) + } + + cfg.Server.FrontendURL = "https://example.com?utm=1" + if err := cfg.Validate(); err == nil { + t.Fatalf("Validate() should reject server.frontend_url with query") + } + + cfg.Server.FrontendURL = "https://user:pass@example.com" + if err := cfg.Validate(); err == nil { + t.Fatalf("Validate() should reject server.frontend_url with userinfo") + } + + cfg.Server.FrontendURL = "/relative" + if err := cfg.Validate(); err == nil { + t.Fatalf("Validate() should reject relative server.frontend_url") + } +} + func TestValidateFrontendRedirectURL(t *testing.T) { if err := ValidateFrontendRedirectURL("/auth/callback"); err != nil { t.Fatalf("ValidateFrontendRedirectURL relative error: %v", err) @@ -445,6 +688,7 @@ func TestValidateFrontendRedirectURL(t *testing.T) { func TestWarnIfInsecureURL(t *testing.T) { warnIfInsecureURL("test", "http://example.com") warnIfInsecureURL("test", "bad://url") + warnIfInsecureURL("test", "://invalid") } func TestGenerateJWTSecretDefaultLength(t *testing.T) { @@ -458,7 +702,7 @@ func TestGenerateJWTSecretDefaultLength(t *testing.T) { } func TestValidateOpsCleanupScheduleRequired(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -476,7 +720,7 @@ func TestValidateOpsCleanupScheduleRequired(t *testing.T) { } func TestValidateConcurrencyPingInterval(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err 
!= nil { @@ -493,14 +737,14 @@ func TestValidateConcurrencyPingInterval(t *testing.T) { } func TestProvideConfig(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) if _, err := ProvideConfig(); err != nil { t.Fatalf("ProvideConfig() error: %v", err) } } func TestValidateConfigWithLinuxDoEnabled(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -544,6 +788,24 @@ func TestGenerateJWTSecretWithLength(t *testing.T) { } } +func TestDatabaseDSNWithTimezone_WithPassword(t *testing.T) { + d := &DatabaseConfig{ + Host: "localhost", + Port: 5432, + User: "u", + Password: "p", + DBName: "db", + SSLMode: "prefer", + } + got := d.DSNWithTimezone("UTC") + if !strings.Contains(got, "password=p") { + t.Fatalf("DSNWithTimezone should include password: %q", got) + } + if !strings.Contains(got, "TimeZone=UTC") { + t.Fatalf("DSNWithTimezone should include TimeZone=UTC: %q", got) + } +} + func TestValidateAbsoluteHTTPURLMissingHost(t *testing.T) { if err := ValidateAbsoluteHTTPURL("https://"); err == nil { t.Fatalf("ValidateAbsoluteHTTPURL should reject missing host") @@ -566,10 +828,35 @@ func TestWarnIfInsecureURLHTTPS(t *testing.T) { warnIfInsecureURL("secure", "https://example.com") } +func TestValidateJWTSecret_UTF8Bytes(t *testing.T) { + resetViperWithJWTSecret(t) + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + // 31 bytes (< 32) even though it's 31 characters. + cfg.JWT.Secret = strings.Repeat("a", 31) + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() should reject 31-byte secret") + } + if !strings.Contains(err.Error(), "at least 32 bytes") { + t.Fatalf("Validate() error = %v", err) + } + + // 32 bytes OK. 
+ cfg.JWT.Secret = strings.Repeat("a", 32) + err = cfg.Validate() + if err != nil { + t.Fatalf("Validate() should accept 32-byte secret: %v", err) + } +} + func TestValidateConfigErrors(t *testing.T) { buildValid := func(t *testing.T) *Config { t.Helper() - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { t.Fatalf("Load() error: %v", err) @@ -582,6 +869,26 @@ func TestValidateConfigErrors(t *testing.T) { mutate func(*Config) wantErr string }{ + { + name: "jwt secret required", + mutate: func(c *Config) { c.JWT.Secret = "" }, + wantErr: "jwt.secret is required", + }, + { + name: "jwt secret min bytes", + mutate: func(c *Config) { c.JWT.Secret = strings.Repeat("a", 31) }, + wantErr: "jwt.secret must be at least 32 bytes", + }, + { + name: "subscription maintenance worker_count non-negative", + mutate: func(c *Config) { c.SubscriptionMaintenance.WorkerCount = -1 }, + wantErr: "subscription_maintenance.worker_count", + }, + { + name: "subscription maintenance queue_size non-negative", + mutate: func(c *Config) { c.SubscriptionMaintenance.QueueSize = -1 }, + wantErr: "subscription_maintenance.queue_size", + }, { name: "jwt expire hour positive", mutate: func(c *Config) { c.JWT.ExpireHour = 0 }, @@ -592,6 +899,11 @@ func TestValidateConfigErrors(t *testing.T) { mutate: func(c *Config) { c.JWT.ExpireHour = 200 }, wantErr: "jwt.expire_hour must be <= 168", }, + { + name: "jwt access token expire minutes non-negative", + mutate: func(c *Config) { c.JWT.AccessTokenExpireMinutes = -1 }, + wantErr: "jwt.access_token_expire_minutes must be non-negative", + }, { name: "csp policy required", mutate: func(c *Config) { c.Security.CSP.Enabled = true; c.Security.CSP.Policy = "" }, @@ -779,6 +1091,16 @@ func TestValidateConfigErrors(t *testing.T) { mutate: func(c *Config) { c.Gateway.StreamKeepaliveInterval = 4 }, wantErr: "gateway.stream_keepalive_interval", }, + { + name: "gateway openai ws oauth max conns factor", + mutate: func(c *Config) { 
c.Gateway.OpenAIWS.OAuthMaxConnsFactor = 0 }, + wantErr: "gateway.openai_ws.oauth_max_conns_factor", + }, + { + name: "gateway openai ws apikey max conns factor", + mutate: func(c *Config) { c.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0 }, + wantErr: "gateway.openai_ws.apikey_max_conns_factor", + }, { name: "gateway stream data interval range", mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = 5 }, @@ -799,6 +1121,84 @@ func TestValidateConfigErrors(t *testing.T) { mutate: func(c *Config) { c.Gateway.MaxLineSize = -1 }, wantErr: "gateway.max_line_size must be non-negative", }, + { + name: "gateway usage record worker count", + mutate: func(c *Config) { c.Gateway.UsageRecord.WorkerCount = 0 }, + wantErr: "gateway.usage_record.worker_count", + }, + { + name: "gateway usage record queue size", + mutate: func(c *Config) { c.Gateway.UsageRecord.QueueSize = 0 }, + wantErr: "gateway.usage_record.queue_size", + }, + { + name: "gateway usage record timeout", + mutate: func(c *Config) { c.Gateway.UsageRecord.TaskTimeoutSeconds = 0 }, + wantErr: "gateway.usage_record.task_timeout_seconds", + }, + { + name: "gateway usage record overflow policy", + mutate: func(c *Config) { c.Gateway.UsageRecord.OverflowPolicy = "invalid" }, + wantErr: "gateway.usage_record.overflow_policy", + }, + { + name: "gateway usage record sample percent range", + mutate: func(c *Config) { c.Gateway.UsageRecord.OverflowSamplePercent = 101 }, + wantErr: "gateway.usage_record.overflow_sample_percent", + }, + { + name: "gateway usage record sample percent required for sample policy", + mutate: func(c *Config) { + c.Gateway.UsageRecord.OverflowPolicy = UsageRecordOverflowPolicySample + c.Gateway.UsageRecord.OverflowSamplePercent = 0 + }, + wantErr: "gateway.usage_record.overflow_sample_percent must be positive", + }, + { + name: "gateway usage record auto scale max gte min", + mutate: func(c *Config) { + c.Gateway.UsageRecord.AutoScaleMinWorkers = 256 + c.Gateway.UsageRecord.AutoScaleMaxWorkers 
= 128 + }, + wantErr: "gateway.usage_record.auto_scale_max_workers", + }, + { + name: "gateway usage record worker in auto scale range", + mutate: func(c *Config) { + c.Gateway.UsageRecord.AutoScaleMinWorkers = 200 + c.Gateway.UsageRecord.AutoScaleMaxWorkers = 300 + c.Gateway.UsageRecord.WorkerCount = 128 + }, + wantErr: "gateway.usage_record.worker_count must be between auto_scale_min_workers and auto_scale_max_workers", + }, + { + name: "gateway usage record auto scale queue thresholds order", + mutate: func(c *Config) { + c.Gateway.UsageRecord.AutoScaleUpQueuePercent = 50 + c.Gateway.UsageRecord.AutoScaleDownQueuePercent = 50 + }, + wantErr: "gateway.usage_record.auto_scale_down_queue_percent must be less", + }, + { + name: "gateway usage record auto scale up step", + mutate: func(c *Config) { c.Gateway.UsageRecord.AutoScaleUpStep = 0 }, + wantErr: "gateway.usage_record.auto_scale_up_step", + }, + { + name: "gateway usage record auto scale interval", + mutate: func(c *Config) { c.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds = 0 }, + wantErr: "gateway.usage_record.auto_scale_check_interval_seconds", + }, + { + name: "gateway user group rate cache ttl", + mutate: func(c *Config) { c.Gateway.UserGroupRateCacheTTLSeconds = 0 }, + wantErr: "gateway.user_group_rate_cache_ttl_seconds", + }, + { + name: "gateway models list cache ttl range", + mutate: func(c *Config) { c.Gateway.ModelsListCacheTTLSeconds = 31 }, + wantErr: "gateway.models_list_cache_ttl_seconds", + }, { name: "gateway scheduling sticky waiting", mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 }, @@ -822,6 +1222,37 @@ func TestValidateConfigErrors(t *testing.T) { }, wantErr: "gateway.scheduling.outbox_lag_rebuild_seconds", }, + { + name: "log level invalid", + mutate: func(c *Config) { c.Log.Level = "trace" }, + wantErr: "log.level", + }, + { + name: "log format invalid", + mutate: func(c *Config) { c.Log.Format = "plain" }, + wantErr: "log.format", + }, + { + name: "log 
output disabled", + mutate: func(c *Config) { + c.Log.Output.ToStdout = false + c.Log.Output.ToFile = false + }, + wantErr: "log.output.to_stdout and log.output.to_file cannot both be false", + }, + { + name: "log rotation size", + mutate: func(c *Config) { c.Log.Rotation.MaxSizeMB = 0 }, + wantErr: "log.rotation.max_size_mb", + }, + { + name: "log sampling enabled invalid", + mutate: func(c *Config) { + c.Log.Sampling.Enabled = true + c.Log.Sampling.Initial = 0 + }, + wantErr: "log.sampling.initial", + }, { name: "ops metrics collector ttl", mutate: func(c *Config) { c.Ops.MetricsCollectorCache.TTL = -1 }, @@ -850,3 +1281,393 @@ func TestValidateConfigErrors(t *testing.T) { }) } } + +func TestValidateConfig_OpenAIWSRules(t *testing.T) { + buildValid := func(t *testing.T) *Config { + t.Helper() + resetViperWithJWTSecret(t) + cfg, err := Load() + require.NoError(t, err) + return cfg + } + + t.Run("sticky response id ttl 兼容旧键回填", func(t *testing.T) { + cfg := buildValid(t) + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 0 + cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = 7200 + + require.NoError(t, cfg.Validate()) + require.Equal(t, 7200, cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds) + }) + + cases := []struct { + name string + mutate func(*Config) + wantErr string + }{ + { + name: "max_conns_per_account 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.MaxConnsPerAccount = 0 }, + wantErr: "gateway.openai_ws.max_conns_per_account", + }, + { + name: "min_idle_per_account 不能为负数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.MinIdlePerAccount = -1 }, + wantErr: "gateway.openai_ws.min_idle_per_account", + }, + { + name: "max_idle_per_account 不能为负数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.MaxIdlePerAccount = -1 }, + wantErr: "gateway.openai_ws.max_idle_per_account", + }, + { + name: "min_idle_per_account 不能大于 max_idle_per_account", + mutate: func(c *Config) { + c.Gateway.OpenAIWS.MinIdlePerAccount = 3 + c.Gateway.OpenAIWS.MaxIdlePerAccount = 
2 + }, + wantErr: "gateway.openai_ws.min_idle_per_account must be <= max_idle_per_account", + }, + { + name: "max_idle_per_account 不能大于 max_conns_per_account", + mutate: func(c *Config) { + c.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + c.Gateway.OpenAIWS.MinIdlePerAccount = 1 + c.Gateway.OpenAIWS.MaxIdlePerAccount = 3 + }, + wantErr: "gateway.openai_ws.max_idle_per_account must be <= max_conns_per_account", + }, + { + name: "dial_timeout_seconds 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.DialTimeoutSeconds = 0 }, + wantErr: "gateway.openai_ws.dial_timeout_seconds", + }, + { + name: "read_timeout_seconds 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.ReadTimeoutSeconds = 0 }, + wantErr: "gateway.openai_ws.read_timeout_seconds", + }, + { + name: "write_timeout_seconds 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.WriteTimeoutSeconds = 0 }, + wantErr: "gateway.openai_ws.write_timeout_seconds", + }, + { + name: "pool_target_utilization 必须在 (0,1]", + mutate: func(c *Config) { c.Gateway.OpenAIWS.PoolTargetUtilization = 0 }, + wantErr: "gateway.openai_ws.pool_target_utilization", + }, + { + name: "queue_limit_per_conn 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.QueueLimitPerConn = 0 }, + wantErr: "gateway.openai_ws.queue_limit_per_conn", + }, + { + name: "fallback_cooldown_seconds 不能为负数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.FallbackCooldownSeconds = -1 }, + wantErr: "gateway.openai_ws.fallback_cooldown_seconds", + }, + { + name: "store_disabled_conn_mode 必须为 strict|adaptive|off", + mutate: func(c *Config) { c.Gateway.OpenAIWS.StoreDisabledConnMode = "invalid" }, + wantErr: "gateway.openai_ws.store_disabled_conn_mode", + }, + { + name: "ingress_mode_default 必须为 off|shared|dedicated", + mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" }, + wantErr: "gateway.openai_ws.ingress_mode_default", + }, + { + name: "payload_log_sample_rate 必须在 [0,1] 范围内", + mutate: func(c *Config) { 
c.Gateway.OpenAIWS.PayloadLogSampleRate = 1.2 }, + wantErr: "gateway.openai_ws.payload_log_sample_rate", + }, + { + name: "retry_total_budget_ms 不能为负数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.RetryTotalBudgetMS = -1 }, + wantErr: "gateway.openai_ws.retry_total_budget_ms", + }, + { + name: "lb_top_k 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.LBTopK = 0 }, + wantErr: "gateway.openai_ws.lb_top_k", + }, + { + name: "sticky_session_ttl_seconds 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.StickySessionTTLSeconds = 0 }, + wantErr: "gateway.openai_ws.sticky_session_ttl_seconds", + }, + { + name: "sticky_response_id_ttl_seconds 必须为正数", + mutate: func(c *Config) { + c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 0 + c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = 0 + }, + wantErr: "gateway.openai_ws.sticky_response_id_ttl_seconds", + }, + { + name: "sticky_previous_response_ttl_seconds 不能为负数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = -1 }, + wantErr: "gateway.openai_ws.sticky_previous_response_ttl_seconds", + }, + { + name: "scheduler_score_weights 不能为负数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = -0.1 }, + wantErr: "gateway.openai_ws.scheduler_score_weights.* must be non-negative", + }, + { + name: "scheduler_score_weights 不能全为 0", + mutate: func(c *Config) { + c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0 + c.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 0 + c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0 + c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0 + c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0 + }, + wantErr: "gateway.openai_ws.scheduler_score_weights must not all be zero", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + cfg := buildValid(t) + tc.mutate(cfg) + + err := cfg.Validate() + require.Error(t, err) + require.Contains(t, err.Error(), tc.wantErr) + }) + } +} + +func 
TestValidateConfig_AutoScaleDisabledIgnoreAutoScaleFields(t *testing.T) { + resetViperWithJWTSecret(t) + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Gateway.UsageRecord.AutoScaleEnabled = false + cfg.Gateway.UsageRecord.WorkerCount = 64 + + // 自动扩缩容关闭时,这些字段应被忽略,不应导致校验失败。 + cfg.Gateway.UsageRecord.AutoScaleMinWorkers = 0 + cfg.Gateway.UsageRecord.AutoScaleMaxWorkers = 0 + cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent = 0 + cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent = 100 + cfg.Gateway.UsageRecord.AutoScaleUpStep = 0 + cfg.Gateway.UsageRecord.AutoScaleDownStep = 0 + cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds = 0 + cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds = -1 + + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate() should ignore auto scale fields when disabled: %v", err) + } +} + +func TestValidateConfig_LogRequiredAndRotationBounds(t *testing.T) { + resetViperWithJWTSecret(t) + + cases := []struct { + name string + mutate func(*Config) + wantErr string + }{ + { + name: "log level required", + mutate: func(c *Config) { + c.Log.Level = "" + }, + wantErr: "log.level is required", + }, + { + name: "log format required", + mutate: func(c *Config) { + c.Log.Format = "" + }, + wantErr: "log.format is required", + }, + { + name: "log stacktrace required", + mutate: func(c *Config) { + c.Log.StacktraceLevel = "" + }, + wantErr: "log.stacktrace_level is required", + }, + { + name: "log max backups non-negative", + mutate: func(c *Config) { + c.Log.Rotation.MaxBackups = -1 + }, + wantErr: "log.rotation.max_backups must be non-negative", + }, + { + name: "log max age non-negative", + mutate: func(c *Config) { + c.Log.Rotation.MaxAgeDays = -1 + }, + wantErr: "log.rotation.max_age_days must be non-negative", + }, + { + name: "sampling thereafter non-negative when disabled", + mutate: func(c *Config) { + c.Log.Sampling.Enabled = false + c.Log.Sampling.Thereafter = -1 + }, + wantErr: 
"log.sampling.thereafter must be non-negative", + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + tt.mutate(cfg) + err = cfg.Validate() + if err == nil || !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("Validate() error = %v, want %q", err, tt.wantErr) + } + }) + } +} + +func TestSoraCurlCFFISidecarDefaults(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if !cfg.Sora.Client.CurlCFFISidecar.Enabled { + t.Fatalf("Sora curl_cffi sidecar should be enabled by default") + } + if cfg.Sora.Client.CloudflareChallengeCooldownSeconds <= 0 { + t.Fatalf("Sora cloudflare challenge cooldown should be positive by default") + } + if cfg.Sora.Client.CurlCFFISidecar.BaseURL == "" { + t.Fatalf("Sora curl_cffi sidecar base_url should not be empty by default") + } + if cfg.Sora.Client.CurlCFFISidecar.Impersonate == "" { + t.Fatalf("Sora curl_cffi sidecar impersonate should not be empty by default") + } + if !cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled { + t.Fatalf("Sora curl_cffi sidecar session reuse should be enabled by default") + } + if cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds <= 0 { + t.Fatalf("Sora curl_cffi sidecar session ttl should be positive by default") + } +} + +func TestValidateSoraCurlCFFISidecarRequired(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Sora.Client.CurlCFFISidecar.Enabled = false + err = cfg.Validate() + if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.enabled must be true") { + t.Fatalf("Validate() error = %v, want sidecar enabled error", err) + } +} + +func TestValidateSoraCurlCFFISidecarBaseURLRequired(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + 
+ cfg.Sora.Client.CurlCFFISidecar.BaseURL = " " + err = cfg.Validate() + if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.base_url is required") { + t.Fatalf("Validate() error = %v, want sidecar base_url required error", err) + } +} + +func TestValidateSoraCurlCFFISidecarSessionTTLNonNegative(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds = -1 + err = cfg.Validate() + if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") { + t.Fatalf("Validate() error = %v, want sidecar session ttl error", err) + } +} + +func TestValidateSoraCloudflareChallengeCooldownNonNegative(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Sora.Client.CloudflareChallengeCooldownSeconds = -1 + err = cfg.Validate() + if err == nil || !strings.Contains(err.Error(), "sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") { + t.Fatalf("Validate() error = %v, want cloudflare cooldown error", err) + } +} + +func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) { + resetViperWithJWTSecret(t) + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + if cfg.Gateway.UsageRecord.WorkerCount != 128 { + t.Fatalf("worker_count = %d, want 128", cfg.Gateway.UsageRecord.WorkerCount) + } + if cfg.Gateway.UsageRecord.QueueSize != 16384 { + t.Fatalf("queue_size = %d, want 16384", cfg.Gateway.UsageRecord.QueueSize) + } + if cfg.Gateway.UsageRecord.TaskTimeoutSeconds != 5 { + t.Fatalf("task_timeout_seconds = %d, want 5", cfg.Gateway.UsageRecord.TaskTimeoutSeconds) + } + if cfg.Gateway.UsageRecord.OverflowPolicy != UsageRecordOverflowPolicySample { + t.Fatalf("overflow_policy = %s, want %s", cfg.Gateway.UsageRecord.OverflowPolicy, 
UsageRecordOverflowPolicySample) + } + if cfg.Gateway.UsageRecord.OverflowSamplePercent != 10 { + t.Fatalf("overflow_sample_percent = %d, want 10", cfg.Gateway.UsageRecord.OverflowSamplePercent) + } + if !cfg.Gateway.UsageRecord.AutoScaleEnabled { + t.Fatalf("auto_scale_enabled = false, want true") + } + if cfg.Gateway.UsageRecord.AutoScaleMinWorkers != 128 { + t.Fatalf("auto_scale_min_workers = %d, want 128", cfg.Gateway.UsageRecord.AutoScaleMinWorkers) + } + if cfg.Gateway.UsageRecord.AutoScaleMaxWorkers != 512 { + t.Fatalf("auto_scale_max_workers = %d, want 512", cfg.Gateway.UsageRecord.AutoScaleMaxWorkers) + } + if cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent != 70 { + t.Fatalf("auto_scale_up_queue_percent = %d, want 70", cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent) + } + if cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent != 15 { + t.Fatalf("auto_scale_down_queue_percent = %d, want 15", cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent) + } + if cfg.Gateway.UsageRecord.AutoScaleUpStep != 32 { + t.Fatalf("auto_scale_up_step = %d, want 32", cfg.Gateway.UsageRecord.AutoScaleUpStep) + } + if cfg.Gateway.UsageRecord.AutoScaleDownStep != 16 { + t.Fatalf("auto_scale_down_step = %d, want 16", cfg.Gateway.UsageRecord.AutoScaleDownStep) + } + if cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds != 3 { + t.Fatalf("auto_scale_check_interval_seconds = %d, want 3", cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds) + } + if cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds != 10 { + t.Fatalf("auto_scale_cooldown_seconds = %d, want 10", cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds) + } +} diff --git a/backend/internal/config/wire.go b/backend/internal/config/wire.go index ec26c401..bf6b3bd6 100644 --- a/backend/internal/config/wire.go +++ b/backend/internal/config/wire.go @@ -9,5 +9,5 @@ var ProviderSet = wire.NewSet( // ProvideConfig 提供应用配置 func ProvideConfig() (*Config, error) { - return Load() + return LoadForBootstrap() } diff --git 
a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index 05b5adc1..d7bb50fc 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -22,6 +22,7 @@ const ( PlatformOpenAI = "openai" PlatformGemini = "gemini" PlatformAntigravity = "antigravity" + PlatformSora = "sora" ) // Account type constants @@ -73,6 +74,7 @@ var DefaultAntigravityModelMapping = map[string]string{ "claude-opus-4-6-thinking": "claude-opus-4-6-thinking", // 官方模型 "claude-opus-4-6": "claude-opus-4-6-thinking", // 简称映射 "claude-opus-4-5-thinking": "claude-opus-4-6-thinking", // 迁移旧模型 + "claude-sonnet-4-6": "claude-sonnet-4-6", "claude-sonnet-4-5": "claude-sonnet-4-5", "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", // Claude 详细版本 ID 映射 @@ -87,14 +89,24 @@ var DefaultAntigravityModelMapping = map[string]string{ "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking", "gemini-2.5-pro": "gemini-2.5-pro", // Gemini 3 白名单 - "gemini-3-flash": "gemini-3-flash", - "gemini-3-pro-high": "gemini-3-pro-high", - "gemini-3-pro-low": "gemini-3-pro-low", - "gemini-3-pro-image": "gemini-3-pro-image", + "gemini-3-flash": "gemini-3-flash", + "gemini-3-pro-high": "gemini-3-pro-high", + "gemini-3-pro-low": "gemini-3-pro-low", // Gemini 3 preview 映射 - "gemini-3-flash-preview": "gemini-3-flash", - "gemini-3-pro-preview": "gemini-3-pro-high", - "gemini-3-pro-image-preview": "gemini-3-pro-image", + "gemini-3-flash-preview": "gemini-3-flash", + "gemini-3-pro-preview": "gemini-3-pro-high", + // Gemini 3.1 白名单 + "gemini-3.1-pro-high": "gemini-3.1-pro-high", + "gemini-3.1-pro-low": "gemini-3.1-pro-low", + // Gemini 3.1 preview 映射 + "gemini-3.1-pro-preview": "gemini-3.1-pro-high", + // Gemini 3.1 image 白名单 + "gemini-3.1-flash-image": "gemini-3.1-flash-image", + // Gemini 3.1 image preview 映射 + "gemini-3.1-flash-image-preview": "gemini-3.1-flash-image", + // Gemini 3 image 兼容映射(向 3.1 image 迁移) + "gemini-3-pro-image": "gemini-3.1-flash-image", + 
"gemini-3-pro-image-preview": "gemini-3.1-flash-image", // 其他官方模型 "gpt-oss-120b-medium": "gpt-oss-120b-medium", "tab_flash_lite_preview": "tab_flash_lite_preview", diff --git a/backend/internal/domain/constants_test.go b/backend/internal/domain/constants_test.go new file mode 100644 index 00000000..29605ac6 --- /dev/null +++ b/backend/internal/domain/constants_test.go @@ -0,0 +1,24 @@ +package domain + +import "testing" + +func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T) { + t.Parallel() + + cases := map[string]string{ + "gemini-3.1-flash-image": "gemini-3.1-flash-image", + "gemini-3.1-flash-image-preview": "gemini-3.1-flash-image", + "gemini-3-pro-image": "gemini-3.1-flash-image", + "gemini-3-pro-image-preview": "gemini-3.1-flash-image", + } + + for from, want := range cases { + got, ok := DefaultAntigravityModelMapping[from] + if !ok { + t.Fatalf("expected mapping for %q to exist", from) + } + if got != want { + t.Fatalf("unexpected mapping for %q: got %q want %q", from, got, want) + } + } +} diff --git a/backend/internal/handler/admin/account_data.go b/backend/internal/handler/admin/account_data.go index 34397696..4ce17219 100644 --- a/backend/internal/handler/admin/account_data.go +++ b/backend/internal/handler/admin/account_data.go @@ -175,22 +175,28 @@ func (h *AccountHandler) ImportData(c *gin.Context) { return } - dataPayload := req.Data - if err := validateDataHeader(dataPayload); err != nil { + if err := validateDataHeader(req.Data); err != nil { response.BadRequest(c, err.Error()) return } + executeAdminIdempotentJSON(c, "admin.accounts.import_data", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + return h.importData(ctx, req) + }) +} + +func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest) (DataImportResult, error) { skipDefaultGroupBind := true if req.SkipDefaultGroupBind != nil { skipDefaultGroupBind = *req.SkipDefaultGroupBind } + dataPayload := req.Data 
result := DataImportResult{} - existingProxies, err := h.listAllProxies(c.Request.Context()) + + existingProxies, err := h.listAllProxies(ctx) if err != nil { - response.ErrorFrom(c, err) - return + return result, err } proxyKeyToID := make(map[string]int64, len(existingProxies)) @@ -221,8 +227,8 @@ func (h *AccountHandler) ImportData(c *gin.Context) { proxyKeyToID[key] = existingID result.ProxyReused++ if normalizedStatus != "" { - if proxy, err := h.adminService.GetProxy(c.Request.Context(), existingID); err == nil && proxy != nil && proxy.Status != normalizedStatus { - _, _ = h.adminService.UpdateProxy(c.Request.Context(), existingID, &service.UpdateProxyInput{ + if proxy, getErr := h.adminService.GetProxy(ctx, existingID); getErr == nil && proxy != nil && proxy.Status != normalizedStatus { + _, _ = h.adminService.UpdateProxy(ctx, existingID, &service.UpdateProxyInput{ Status: normalizedStatus, }) } @@ -230,7 +236,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) { continue } - created, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{ + created, createErr := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{ Name: defaultProxyName(item.Name), Protocol: item.Protocol, Host: item.Host, @@ -238,13 +244,13 @@ func (h *AccountHandler) ImportData(c *gin.Context) { Username: item.Username, Password: item.Password, }) - if err != nil { + if createErr != nil { result.ProxyFailed++ result.Errors = append(result.Errors, DataImportError{ Kind: "proxy", Name: item.Name, ProxyKey: key, - Message: err.Error(), + Message: createErr.Error(), }) continue } @@ -252,7 +258,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) { result.ProxyCreated++ if normalizedStatus != "" && normalizedStatus != created.Status { - _, _ = h.adminService.UpdateProxy(c.Request.Context(), created.ID, &service.UpdateProxyInput{ + _, _ = h.adminService.UpdateProxy(ctx, created.ID, &service.UpdateProxyInput{ Status: normalizedStatus, }) } @@ -303,7 
+309,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) { SkipDefaultGroupBind: skipDefaultGroupBind, } - if _, err := h.adminService.CreateAccount(c.Request.Context(), accountInput); err != nil { + if _, err := h.adminService.CreateAccount(ctx, accountInput); err != nil { result.AccountFailed++ result.Errors = append(result.Errors, DataImportError{ Kind: "account", @@ -315,7 +321,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) { result.AccountCreated++ } - response.Success(c, result) + return result, nil } func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, error) { diff --git a/backend/internal/handler/admin/account_data_handler_test.go b/backend/internal/handler/admin/account_data_handler_test.go index c8b04c2a..285033a1 100644 --- a/backend/internal/handler/admin/account_data_handler_test.go +++ b/backend/internal/handler/admin/account_data_handler_test.go @@ -64,6 +64,7 @@ func setupAccountDataRouter() (*gin.Engine, *stubAdminService) { nil, nil, nil, + nil, ) router.GET("/api/v1/admin/accounts/data", h.ExportData) diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 0fae04ac..4f93f702 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -2,7 +2,13 @@ package admin import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" "errors" + "fmt" + "net/http" "strconv" "strings" "sync" @@ -10,6 +16,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/handler/dto" + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" @@ -46,6 +53,7 @@ type AccountHandler struct { concurrencyService *service.ConcurrencyService crsSyncService *service.CRSSyncService sessionLimitCache 
service.SessionLimitCache + rpmCache service.RPMCache tokenCacheInvalidator service.TokenCacheInvalidator } @@ -62,6 +70,7 @@ func NewAccountHandler( concurrencyService *service.ConcurrencyService, crsSyncService *service.CRSSyncService, sessionLimitCache service.SessionLimitCache, + rpmCache service.RPMCache, tokenCacheInvalidator service.TokenCacheInvalidator, ) *AccountHandler { return &AccountHandler{ @@ -76,6 +85,7 @@ func NewAccountHandler( concurrencyService: concurrencyService, crsSyncService: crsSyncService, sessionLimitCache: sessionLimitCache, + rpmCache: rpmCache, tokenCacheInvalidator: tokenCacheInvalidator, } } @@ -133,6 +143,13 @@ type BulkUpdateAccountsRequest struct { ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险 } +// CheckMixedChannelRequest represents check mixed channel risk request +type CheckMixedChannelRequest struct { + Platform string `json:"platform" binding:"required"` + GroupIDs []int64 `json:"group_ids"` + AccountID *int64 `json:"account_id"` +} + // AccountWithConcurrency extends Account with real-time concurrency info type AccountWithConcurrency struct { *dto.Account @@ -140,6 +157,51 @@ type AccountWithConcurrency struct { // 以下字段仅对 Anthropic OAuth/SetupToken 账号有效,且仅在启用相应功能时返回 CurrentWindowCost *float64 `json:"current_window_cost,omitempty"` // 当前窗口费用 ActiveSessions *int `json:"active_sessions,omitempty"` // 当前活跃会话数 + CurrentRPM *int `json:"current_rpm,omitempty"` // 当前分钟 RPM 计数 +} + +func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency { + item := AccountWithConcurrency{ + Account: dto.AccountFromService(account), + CurrentConcurrency: 0, + } + if account == nil { + return item + } + + if h.concurrencyService != nil { + if counts, err := h.concurrencyService.GetAccountConcurrencyBatch(ctx, []int64{account.ID}); err == nil { + item.CurrentConcurrency = counts[account.ID] + } + } + + if account.IsAnthropicOAuthOrSetupToken() { 
+ if h.accountUsageService != nil && account.GetWindowCostLimit() > 0 { + startTime := account.GetCurrentWindowStartTime() + if stats, err := h.accountUsageService.GetAccountWindowStats(ctx, account.ID, startTime); err == nil && stats != nil { + cost := stats.StandardCost + item.CurrentWindowCost = &cost + } + } + + if h.sessionLimitCache != nil && account.GetMaxSessions() > 0 { + idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute + idleTimeouts := map[int64]time.Duration{account.ID: idleTimeout} + if sessions, err := h.sessionLimitCache.GetActiveSessionCountBatch(ctx, []int64{account.ID}, idleTimeouts); err == nil { + if count, ok := sessions[account.ID]; ok { + item.ActiveSessions = &count + } + } + } + + if h.rpmCache != nil && account.GetBaseRPM() > 0 { + if rpm, err := h.rpmCache.GetRPM(ctx, account.ID); err == nil { + item.CurrentRPM = &rpm + } + } + } + + return item } // List handles listing all accounts with pagination @@ -155,6 +217,7 @@ func (h *AccountHandler) List(c *gin.Context) { if len(search) > 100 { search = search[:100] } + lite := parseBoolQueryWithDefault(c.Query("lite"), false) var groupID int64 if groupIDStr := c.Query("group"); groupIDStr != "" { @@ -173,67 +236,81 @@ func (h *AccountHandler) List(c *gin.Context) { accountIDs[i] = acc.ID } - concurrencyCounts, err := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs) - if err != nil { - // Log error but don't fail the request, just use 0 for all - concurrencyCounts = make(map[int64]int) - } - - // 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能) - windowCostAccountIDs := make([]int64, 0) - sessionLimitAccountIDs := make([]int64, 0) - sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置 - for i := range accounts { - acc := &accounts[i] - if acc.IsAnthropicOAuthOrSetupToken() { - if acc.GetWindowCostLimit() > 0 { - windowCostAccountIDs = append(windowCostAccountIDs, acc.ID) - } - if acc.GetMaxSessions() > 0 { - 
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID) - sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute - } - } - } - - // 并行获取窗口费用和活跃会话数 + concurrencyCounts := make(map[int64]int) var windowCosts map[int64]float64 var activeSessions map[int64]int - - // 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置) - if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil { - activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts) - if activeSessions == nil { - activeSessions = make(map[int64]int) + var rpmCounts map[int64]int + if !lite { + // Get current concurrency counts for all accounts + if h.concurrencyService != nil { + if cc, ccErr := h.concurrencyService.GetAccountConcurrencyBatch(c.Request.Context(), accountIDs); ccErr == nil && cc != nil { + concurrencyCounts = cc + } } - } - - // 获取窗口费用(并行查询) - if len(windowCostAccountIDs) > 0 { - windowCosts = make(map[int64]float64) - var mu sync.Mutex - g, gctx := errgroup.WithContext(c.Request.Context()) - g.SetLimit(10) // 限制并发数 - + // 识别需要查询窗口费用、会话数和 RPM 的账号(Anthropic OAuth/SetupToken 且启用了相应功能) + windowCostAccountIDs := make([]int64, 0) + sessionLimitAccountIDs := make([]int64, 0) + rpmAccountIDs := make([]int64, 0) + sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置 for i := range accounts { acc := &accounts[i] - if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 { - continue - } - accCopy := acc // 闭包捕获 - g.Go(func() error { - // 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况) - startTime := accCopy.GetCurrentWindowStartTime() - stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime) - if err == nil && stats != nil { - mu.Lock() - windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用 - mu.Unlock() + if acc.IsAnthropicOAuthOrSetupToken() { + if acc.GetWindowCostLimit() > 0 { + windowCostAccountIDs = append(windowCostAccountIDs, acc.ID) } - return 
nil // 不返回错误,允许部分失败 - }) + if acc.GetMaxSessions() > 0 { + sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID) + sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute + } + if acc.GetBaseRPM() > 0 { + rpmAccountIDs = append(rpmAccountIDs, acc.ID) + } + } + } + + // 获取 RPM 计数(批量查询) + if len(rpmAccountIDs) > 0 && h.rpmCache != nil { + rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs) + if rpmCounts == nil { + rpmCounts = make(map[int64]int) + } + } + + // 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置) + if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil { + activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts) + if activeSessions == nil { + activeSessions = make(map[int64]int) + } + } + + // 获取窗口费用(并行查询) + if len(windowCostAccountIDs) > 0 { + windowCosts = make(map[int64]float64) + var mu sync.Mutex + g, gctx := errgroup.WithContext(c.Request.Context()) + g.SetLimit(10) // 限制并发数 + + for i := range accounts { + acc := &accounts[i] + if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 { + continue + } + accCopy := acc // 闭包捕获 + g.Go(func() error { + // 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况) + startTime := accCopy.GetCurrentWindowStartTime() + stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime) + if err == nil && stats != nil { + mu.Lock() + windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用 + mu.Unlock() + } + return nil // 不返回错误,允许部分失败 + }) + } + _ = g.Wait() } - _ = g.Wait() } // Build response with concurrency info @@ -259,12 +336,84 @@ func (h *AccountHandler) List(c *gin.Context) { } } + // 添加 RPM 计数(仅当启用时) + if rpmCounts != nil { + if rpm, ok := rpmCounts[acc.ID]; ok { + item.CurrentRPM = &rpm + } + } + result[i] = item } + etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search, lite) + if etag != "" { + 
c.Header("ETag", etag) + c.Header("Vary", "If-None-Match") + if ifNoneMatchMatched(c.GetHeader("If-None-Match"), etag) { + c.Status(http.StatusNotModified) + return + } + } + response.Paginated(c, result, total, page, pageSize) } +func buildAccountsListETag( + items []AccountWithConcurrency, + total int64, + page, pageSize int, + platform, accountType, status, search string, + lite bool, +) string { + payload := struct { + Total int64 `json:"total"` + Page int `json:"page"` + PageSize int `json:"page_size"` + Platform string `json:"platform"` + AccountType string `json:"type"` + Status string `json:"status"` + Search string `json:"search"` + Lite bool `json:"lite"` + Items []AccountWithConcurrency `json:"items"` + }{ + Total: total, + Page: page, + PageSize: pageSize, + Platform: platform, + AccountType: accountType, + Status: status, + Search: search, + Lite: lite, + Items: items, + } + raw, err := json.Marshal(payload) + if err != nil { + return "" + } + sum := sha256.Sum256(raw) + return "\"" + hex.EncodeToString(sum[:]) + "\"" +} + +func ifNoneMatchMatched(ifNoneMatch, etag string) bool { + if etag == "" || ifNoneMatch == "" { + return false + } + for _, token := range strings.Split(ifNoneMatch, ",") { + candidate := strings.TrimSpace(token) + if candidate == "*" { + return true + } + if candidate == etag { + return true + } + if strings.HasPrefix(candidate, "W/") && strings.TrimPrefix(candidate, "W/") == etag { + return true + } + } + return false +} + // GetByID handles getting an account by ID // GET /api/v1/admin/accounts/:id func (h *AccountHandler) GetByID(c *gin.Context) { @@ -280,7 +429,51 @@ func (h *AccountHandler) GetByID(c *gin.Context) { return } - response.Success(c, dto.AccountFromService(account)) + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) +} + +// CheckMixedChannel handles checking mixed channel risk for account-group binding. 
+// POST /api/v1/admin/accounts/check-mixed-channel +func (h *AccountHandler) CheckMixedChannel(c *gin.Context) { + var req CheckMixedChannelRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if len(req.GroupIDs) == 0 { + response.Success(c, gin.H{"has_risk": false}) + return + } + + accountID := int64(0) + if req.AccountID != nil { + accountID = *req.AccountID + } + + err := h.adminService.CheckMixedChannelRisk(c.Request.Context(), accountID, req.Platform, req.GroupIDs) + if err != nil { + var mixedErr *service.MixedChannelError + if errors.As(err, &mixedErr) { + response.Success(c, gin.H{ + "has_risk": true, + "error": "mixed_channel_warning", + "message": mixedErr.Error(), + "details": gin.H{ + "group_id": mixedErr.GroupID, + "group_name": mixedErr.GroupName, + "current_platform": mixedErr.CurrentPlatform, + "other_platform": mixedErr.OtherPlatform, + }, + }) + return + } + + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"has_risk": false}) } // Create handles creating a new account @@ -295,50 +488,57 @@ func (h *AccountHandler) Create(c *gin.Context) { response.BadRequest(c, "rate_multiplier must be >= 0") return } + // base_rpm 输入校验:负值归零,超过 10000 截断 + sanitizeExtraBaseRPM(req.Extra) // 确定是否跳过混合渠道检查 skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk - account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{ - Name: req.Name, - Notes: req.Notes, - Platform: req.Platform, - Type: req.Type, - Credentials: req.Credentials, - Extra: req.Extra, - ProxyID: req.ProxyID, - Concurrency: req.Concurrency, - Priority: req.Priority, - RateMultiplier: req.RateMultiplier, - GroupIDs: req.GroupIDs, - ExpiresAt: req.ExpiresAt, - AutoPauseOnExpired: req.AutoPauseOnExpired, - SkipMixedChannelCheck: skipCheck, + result, err := executeAdminIdempotent(c, "admin.accounts.create", req, 
service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + account, execErr := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{ + Name: req.Name, + Notes: req.Notes, + Platform: req.Platform, + Type: req.Type, + Credentials: req.Credentials, + Extra: req.Extra, + ProxyID: req.ProxyID, + Concurrency: req.Concurrency, + Priority: req.Priority, + RateMultiplier: req.RateMultiplier, + GroupIDs: req.GroupIDs, + ExpiresAt: req.ExpiresAt, + AutoPauseOnExpired: req.AutoPauseOnExpired, + SkipMixedChannelCheck: skipCheck, + }) + if execErr != nil { + return nil, execErr + } + return h.buildAccountResponseWithRuntime(ctx, account), nil }) if err != nil { // 检查是否为混合渠道错误 var mixedErr *service.MixedChannelError if errors.As(err, &mixedErr) { - // 返回特殊错误码要求确认 + // 创建接口仅返回最小必要字段,详细信息由专门检查接口提供 c.JSON(409, gin.H{ "error": "mixed_channel_warning", "message": mixedErr.Error(), - "details": gin.H{ - "group_id": mixedErr.GroupID, - "group_name": mixedErr.GroupName, - "current_platform": mixedErr.CurrentPlatform, - "other_platform": mixedErr.OtherPlatform, - }, - "require_confirmation": true, }) return } + if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } response.ErrorFrom(c, err) return } - response.Success(c, dto.AccountFromService(account)) + if result != nil && result.Replayed { + c.Header("X-Idempotency-Replayed", "true") + } + response.Success(c, result.Data) } // Update handles updating an account @@ -359,6 +559,8 @@ func (h *AccountHandler) Update(c *gin.Context) { response.BadRequest(c, "rate_multiplier must be >= 0") return } + // base_rpm 输入校验:负值归零,超过 10000 截断 + sanitizeExtraBaseRPM(req.Extra) // 确定是否跳过混合渠道检查 skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk @@ -383,17 +585,10 @@ func (h *AccountHandler) Update(c *gin.Context) { // 检查是否为混合渠道错误 var mixedErr *service.MixedChannelError if errors.As(err, &mixedErr) { - // 返回特殊错误码要求确认 + 
// 更新接口仅返回最小必要字段,详细信息由专门检查接口提供 c.JSON(409, gin.H{ "error": "mixed_channel_warning", "message": mixedErr.Error(), - "details": gin.H{ - "group_id": mixedErr.GroupID, - "group_name": mixedErr.GroupName, - "current_platform": mixedErr.CurrentPlatform, - "other_platform": mixedErr.OtherPlatform, - }, - "require_confirmation": true, }) return } @@ -402,7 +597,7 @@ func (h *AccountHandler) Update(c *gin.Context) { return } - response.Success(c, dto.AccountFromService(account)) + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } // Delete handles deleting an account @@ -660,7 +855,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) { } } - response.Success(c, dto.AccountFromService(updatedAccount)) + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount)) } // GetStats handles getting account statistics @@ -718,7 +913,7 @@ func (h *AccountHandler) ClearError(c *gin.Context) { } } - response.Success(c, dto.AccountFromService(account)) + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } // BatchCreate handles batch creating accounts @@ -732,61 +927,65 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) { return } - ctx := c.Request.Context() - success := 0 - failed := 0 - results := make([]gin.H, 0, len(req.Accounts)) + executeAdminIdempotentJSON(c, "admin.accounts.batch_create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + success := 0 + failed := 0 + results := make([]gin.H, 0, len(req.Accounts)) - for _, item := range req.Accounts { - if item.RateMultiplier != nil && *item.RateMultiplier < 0 { - failed++ + for _, item := range req.Accounts { + if item.RateMultiplier != nil && *item.RateMultiplier < 0 { + failed++ + results = append(results, gin.H{ + "name": item.Name, + "success": false, + "error": "rate_multiplier must be >= 0", + }) + continue + } + + // base_rpm 输入校验:负值归零,超过 10000 截断 + 
sanitizeExtraBaseRPM(item.Extra) + + skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk + + account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{ + Name: item.Name, + Notes: item.Notes, + Platform: item.Platform, + Type: item.Type, + Credentials: item.Credentials, + Extra: item.Extra, + ProxyID: item.ProxyID, + Concurrency: item.Concurrency, + Priority: item.Priority, + RateMultiplier: item.RateMultiplier, + GroupIDs: item.GroupIDs, + ExpiresAt: item.ExpiresAt, + AutoPauseOnExpired: item.AutoPauseOnExpired, + SkipMixedChannelCheck: skipCheck, + }) + if err != nil { + failed++ + results = append(results, gin.H{ + "name": item.Name, + "success": false, + "error": err.Error(), + }) + continue + } + success++ results = append(results, gin.H{ "name": item.Name, - "success": false, - "error": "rate_multiplier must be >= 0", + "id": account.ID, + "success": true, }) - continue } - skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk - - account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{ - Name: item.Name, - Notes: item.Notes, - Platform: item.Platform, - Type: item.Type, - Credentials: item.Credentials, - Extra: item.Extra, - ProxyID: item.ProxyID, - Concurrency: item.Concurrency, - Priority: item.Priority, - RateMultiplier: item.RateMultiplier, - GroupIDs: item.GroupIDs, - ExpiresAt: item.ExpiresAt, - AutoPauseOnExpired: item.AutoPauseOnExpired, - SkipMixedChannelCheck: skipCheck, - }) - if err != nil { - failed++ - results = append(results, gin.H{ - "name": item.Name, - "success": false, - "error": err.Error(), - }) - continue - } - success++ - results = append(results, gin.H{ - "name": item.Name, - "id": account.ID, - "success": true, - }) - } - - response.Success(c, gin.H{ - "success": success, - "failed": failed, - "results": results, + return gin.H{ + "success": success, + "failed": failed, + "results": results, + }, nil }) } @@ -824,57 +1023,58 @@ func 
(h *AccountHandler) BatchUpdateCredentials(c *gin.Context) { } ctx := c.Request.Context() - success := 0 - failed := 0 - results := []gin.H{} + // 阶段一:预验证所有账号存在,收集 credentials + type accountUpdate struct { + ID int64 + Credentials map[string]any + } + updates := make([]accountUpdate, 0, len(req.AccountIDs)) for _, accountID := range req.AccountIDs { - // Get account account, err := h.adminService.GetAccount(ctx, accountID) if err != nil { - failed++ - results = append(results, gin.H{ - "account_id": accountID, - "success": false, - "error": "Account not found", - }) - continue + response.Error(c, 404, fmt.Sprintf("Account %d not found", accountID)) + return } - - // Update credentials field if account.Credentials == nil { account.Credentials = make(map[string]any) } - account.Credentials[req.Field] = req.Value + updates = append(updates, accountUpdate{ID: accountID, Credentials: account.Credentials}) + } - // Update account - updateInput := &service.UpdateAccountInput{ - Credentials: account.Credentials, - } - - _, err = h.adminService.UpdateAccount(ctx, accountID, updateInput) - if err != nil { + // 阶段二:依次更新,返回每个账号的成功/失败明细,便于调用方重试 + success := 0 + failed := 0 + successIDs := make([]int64, 0, len(updates)) + failedIDs := make([]int64, 0, len(updates)) + results := make([]gin.H, 0, len(updates)) + for _, u := range updates { + updateInput := &service.UpdateAccountInput{Credentials: u.Credentials} + if _, err := h.adminService.UpdateAccount(ctx, u.ID, updateInput); err != nil { failed++ + failedIDs = append(failedIDs, u.ID) results = append(results, gin.H{ - "account_id": accountID, + "account_id": u.ID, "success": false, "error": err.Error(), }) continue } - success++ + successIDs = append(successIDs, u.ID) results = append(results, gin.H{ - "account_id": accountID, + "account_id": u.ID, "success": true, }) } response.Success(c, gin.H{ - "success": success, - "failed": failed, - "results": results, + "success": success, + "failed": failed, + "success_ids": 
successIDs, + "failed_ids": failedIDs, + "results": results, }) } @@ -890,6 +1090,8 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { response.BadRequest(c, "rate_multiplier must be >= 0") return } + // base_rpm 输入校验:负值归零,超过 10000 截断 + sanitizeExtraBaseRPM(req.Extra) // 确定是否跳过混合渠道检查 skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk @@ -925,6 +1127,14 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { SkipMixedChannelCheck: skipCheck, }) if err != nil { + var mixedErr *service.MixedChannelError + if errors.As(err, &mixedErr) { + c.JSON(409, gin.H{ + "error": "mixed_channel_warning", + "message": mixedErr.Error(), + }) + return + } response.ErrorFrom(c, err) return } @@ -1109,7 +1319,13 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) { return } - response.Success(c, gin.H{"message": "Rate limit cleared successfully"}) + account, err := h.adminService.GetAccount(c.Request.Context(), accountID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } // GetTempUnschedulable handles getting temporary unschedulable status @@ -1173,6 +1389,57 @@ func (h *AccountHandler) GetTodayStats(c *gin.Context) { response.Success(c, stats) } +// BatchTodayStatsRequest 批量今日统计请求体。 +type BatchTodayStatsRequest struct { + AccountIDs []int64 `json:"account_ids" binding:"required"` +} + +// GetBatchTodayStats 批量获取多个账号的今日统计。 +// POST /api/v1/admin/accounts/today-stats/batch +func (h *AccountHandler) GetBatchTodayStats(c *gin.Context) { + var req BatchTodayStatsRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + accountIDs := normalizeInt64IDList(req.AccountIDs) + if len(accountIDs) == 0 { + response.Success(c, gin.H{"stats": map[string]any{}}) + return + } + + cacheKey := buildAccountTodayStatsBatchCacheKey(accountIDs) + if cached, ok := 
accountTodayStatsBatchCache.Get(cacheKey); ok { + if cached.ETag != "" { + c.Header("ETag", cached.ETag) + c.Header("Vary", "If-None-Match") + if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) { + c.Status(http.StatusNotModified) + return + } + } + c.Header("X-Snapshot-Cache", "hit") + response.Success(c, cached.Payload) + return + } + + stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), accountIDs) + if err != nil { + response.ErrorFrom(c, err) + return + } + + payload := gin.H{"stats": stats} + cached := accountTodayStatsBatchCache.Set(cacheKey, payload) + if cached.ETag != "" { + c.Header("ETag", cached.ETag) + c.Header("Vary", "If-None-Match") + } + c.Header("X-Snapshot-Cache", "miss") + response.Success(c, payload) +} + // SetSchedulableRequest represents the request body for setting schedulable status type SetSchedulableRequest struct { Schedulable bool `json:"schedulable"` @@ -1199,7 +1466,7 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) { return } - response.Success(c, dto.AccountFromService(account)) + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } // GetAvailableModels handles getting available models for an account @@ -1296,32 +1563,14 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { // Handle Antigravity accounts: return Claude + Gemini models if account.Platform == service.PlatformAntigravity { - // Antigravity 支持 Claude 和部分 Gemini 模型 - type UnifiedModel struct { - ID string `json:"id"` - Type string `json:"type"` - DisplayName string `json:"display_name"` - } + // 直接复用 antigravity.DefaultModels(),与 /v1/models 端点保持同步 + response.Success(c, antigravity.DefaultModels()) + return + } - var models []UnifiedModel - - // 添加 Claude 模型 - for _, m := range claude.DefaultModels { - models = append(models, UnifiedModel{ - ID: m.ID, - Type: m.Type, - DisplayName: m.DisplayName, - }) - } - - // 添加 Gemini 3 系列模型用于测试 - geminiTestModels := []UnifiedModel{ - {ID: 
"gemini-3-flash", Type: "model", DisplayName: "Gemini 3 Flash"}, - {ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview"}, - } - models = append(models, geminiTestModels...) - - response.Success(c, models) + // Handle Sora accounts + if account.Platform == service.PlatformSora { + response.Success(c, service.DefaultSoraModels(nil)) return } @@ -1532,3 +1781,22 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) { func (h *AccountHandler) GetAntigravityDefaultModelMapping(c *gin.Context) { response.Success(c, domain.DefaultAntigravityModelMapping) } + +// sanitizeExtraBaseRPM 对 extra map 中的 base_rpm 值进行范围校验和归一化。 +// 负值归零,超过 10000 截断为 10000。extra 为 nil 或不含 base_rpm 时无操作。 +func sanitizeExtraBaseRPM(extra map[string]any) { + if extra == nil { + return + } + raw, ok := extra["base_rpm"] + if !ok { + return + } + v := service.ParseExtraInt(raw) + if v < 0 { + v = 0 + } else if v > 10000 { + v = 10000 + } + extra["base_rpm"] = v +} diff --git a/backend/internal/handler/admin/account_handler_mixed_channel_test.go b/backend/internal/handler/admin/account_handler_mixed_channel_test.go new file mode 100644 index 00000000..24ec5bcf --- /dev/null +++ b/backend/internal/handler/admin/account_handler_mixed_channel_test.go @@ -0,0 +1,198 @@ +package admin + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func setupAccountMixedChannelRouter(adminSvc *stubAdminService) *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + router.POST("/api/v1/admin/accounts/check-mixed-channel", accountHandler.CheckMixedChannel) + router.POST("/api/v1/admin/accounts", accountHandler.Create) + router.PUT("/api/v1/admin/accounts/:id", accountHandler.Update) + 
router.POST("/api/v1/admin/accounts/bulk-update", accountHandler.BulkUpdate) + return router +} + +func TestAccountHandlerCheckMixedChannelNoRisk(t *testing.T) { + adminSvc := newStubAdminService() + router := setupAccountMixedChannelRouter(adminSvc) + + body, _ := json.Marshal(map[string]any{ + "platform": "antigravity", + "group_ids": []int64{27}, + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, float64(0), resp["code"]) + data, ok := resp["data"].(map[string]any) + require.True(t, ok) + require.Equal(t, false, data["has_risk"]) + require.Equal(t, int64(0), adminSvc.lastMixedCheck.accountID) + require.Equal(t, "antigravity", adminSvc.lastMixedCheck.platform) + require.Equal(t, []int64{27}, adminSvc.lastMixedCheck.groupIDs) +} + +func TestAccountHandlerCheckMixedChannelWithRisk(t *testing.T) { + adminSvc := newStubAdminService() + adminSvc.checkMixedErr = &service.MixedChannelError{ + GroupID: 27, + GroupName: "claude-max", + CurrentPlatform: "Antigravity", + OtherPlatform: "Anthropic", + } + router := setupAccountMixedChannelRouter(adminSvc) + + body, _ := json.Marshal(map[string]any{ + "platform": "antigravity", + "group_ids": []int64{27}, + "account_id": 99, + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, float64(0), resp["code"]) + data, ok := resp["data"].(map[string]any) + require.True(t, ok) + 
require.Equal(t, true, data["has_risk"]) + require.Equal(t, "mixed_channel_warning", data["error"]) + details, ok := data["details"].(map[string]any) + require.True(t, ok) + require.Equal(t, float64(27), details["group_id"]) + require.Equal(t, "claude-max", details["group_name"]) + require.Equal(t, "Antigravity", details["current_platform"]) + require.Equal(t, "Anthropic", details["other_platform"]) + require.Equal(t, int64(99), adminSvc.lastMixedCheck.accountID) +} + +func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T) { + adminSvc := newStubAdminService() + adminSvc.createAccountErr = &service.MixedChannelError{ + GroupID: 27, + GroupName: "claude-max", + CurrentPlatform: "Antigravity", + OtherPlatform: "Anthropic", + } + router := setupAccountMixedChannelRouter(adminSvc) + + body, _ := json.Marshal(map[string]any{ + "name": "ag-oauth-1", + "platform": "antigravity", + "type": "oauth", + "credentials": map[string]any{"refresh_token": "rt"}, + "group_ids": []int64{27}, + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusConflict, rec.Code) + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, "mixed_channel_warning", resp["error"]) + require.Contains(t, resp["message"], "mixed_channel_warning") + _, hasDetails := resp["details"] + _, hasRequireConfirmation := resp["require_confirmation"] + require.False(t, hasDetails) + require.False(t, hasRequireConfirmation) +} + +func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T) { + adminSvc := newStubAdminService() + adminSvc.updateAccountErr = &service.MixedChannelError{ + GroupID: 27, + GroupName: "claude-max", + CurrentPlatform: "Antigravity", + OtherPlatform: "Anthropic", + } + router := 
setupAccountMixedChannelRouter(adminSvc) + + body, _ := json.Marshal(map[string]any{ + "group_ids": []int64{27}, + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/accounts/3", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusConflict, rec.Code) + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, "mixed_channel_warning", resp["error"]) + require.Contains(t, resp["message"], "mixed_channel_warning") + _, hasDetails := resp["details"] + _, hasRequireConfirmation := resp["require_confirmation"] + require.False(t, hasDetails) + require.False(t, hasRequireConfirmation) +} + +func TestAccountHandlerBulkUpdateMixedChannelConflict(t *testing.T) { + adminSvc := newStubAdminService() + adminSvc.bulkUpdateAccountErr = &service.MixedChannelError{ + GroupID: 27, + GroupName: "claude-max", + CurrentPlatform: "Antigravity", + OtherPlatform: "Anthropic", + } + router := setupAccountMixedChannelRouter(adminSvc) + + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1, 2, 3}, + "group_ids": []int64{27}, + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusConflict, rec.Code) + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, "mixed_channel_warning", resp["error"]) + require.Contains(t, resp["message"], "claude-max") +} + +func TestAccountHandlerBulkUpdateMixedChannelConfirmSkips(t *testing.T) { + adminSvc := newStubAdminService() + router := setupAccountMixedChannelRouter(adminSvc) + + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1, 2}, + "group_ids": []int64{27}, + "confirm_mixed_channel_risk": true, + }) + 
rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, float64(0), resp["code"]) + data, ok := resp["data"].(map[string]any) + require.True(t, ok) + require.Equal(t, float64(2), data["success"]) + require.Equal(t, float64(0), data["failed"]) +} diff --git a/backend/internal/handler/admin/account_handler_passthrough_test.go b/backend/internal/handler/admin/account_handler_passthrough_test.go new file mode 100644 index 00000000..d86501c0 --- /dev/null +++ b/backend/internal/handler/admin/account_handler_passthrough_test.go @@ -0,0 +1,67 @@ +package admin + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestAccountHandler_Create_AnthropicAPIKeyPassthroughExtraForwarded(t *testing.T) { + gin.SetMode(gin.TestMode) + + adminSvc := newStubAdminService() + handler := NewAccountHandler( + adminSvc, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + ) + + router := gin.New() + router.POST("/api/v1/admin/accounts", handler.Create) + + body := map[string]any{ + "name": "anthropic-key-1", + "platform": "anthropic", + "type": "apikey", + "credentials": map[string]any{ + "api_key": "sk-ant-xxx", + "base_url": "https://api.anthropic.com", + }, + "extra": map[string]any{ + "anthropic_passthrough": true, + }, + "concurrency": 1, + "priority": 1, + } + raw, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", bytes.NewReader(raw)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + 
require.Equal(t, http.StatusOK, rec.Code) + require.Len(t, adminSvc.createdAccounts, 1) + + created := adminSvc.createdAccounts[0] + require.Equal(t, "anthropic", created.Platform) + require.Equal(t, "apikey", created.Type) + require.NotNil(t, created.Extra) + require.Equal(t, true, created.Extra["anthropic_passthrough"]) +} diff --git a/backend/internal/handler/admin/account_today_stats_cache.go b/backend/internal/handler/admin/account_today_stats_cache.go new file mode 100644 index 00000000..61922f70 --- /dev/null +++ b/backend/internal/handler/admin/account_today_stats_cache.go @@ -0,0 +1,25 @@ +package admin + +import ( + "strconv" + "strings" + "time" +) + +var accountTodayStatsBatchCache = newSnapshotCache(30 * time.Second) + +func buildAccountTodayStatsBatchCacheKey(accountIDs []int64) string { + if len(accountIDs) == 0 { + return "accounts_today_stats_empty" + } + var b strings.Builder + b.Grow(len(accountIDs) * 6) + _, _ = b.WriteString("accounts_today_stats:") + for i, id := range accountIDs { + if i > 0 { + _ = b.WriteByte(',') + } + _, _ = b.WriteString(strconv.FormatInt(id, 10)) + } + return b.String() +} diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go index 20a25222..4de10d3e 100644 --- a/backend/internal/handler/admin/admin_basic_handlers_test.go +++ b/backend/internal/handler/admin/admin_basic_handlers_test.go @@ -19,7 +19,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { userHandler := NewUserHandler(adminSvc, nil) groupHandler := NewGroupHandler(adminSvc) proxyHandler := NewProxyHandler(adminSvc) - redeemHandler := NewRedeemHandler(adminSvc) + redeemHandler := NewRedeemHandler(adminSvc, nil) router.GET("/api/v1/admin/users", userHandler.List) router.GET("/api/v1/admin/users/:id", userHandler.GetByID) @@ -47,6 +47,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete) 
router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete) router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test) + router.POST("/api/v1/admin/proxies/:id/quality-check", proxyHandler.CheckQuality) router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats) router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts) @@ -208,6 +209,11 @@ func TestProxyHandlerEndpoints(t *testing.T) { router.ServeHTTP(rec, req) require.Equal(t, http.StatusOK, rec.Code) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/quality-check", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + rec = httptest.NewRecorder() req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil) router.ServeHTTP(rec, req) diff --git a/backend/internal/handler/admin/admin_helpers_test.go b/backend/internal/handler/admin/admin_helpers_test.go index 863c755c..3833d32e 100644 --- a/backend/internal/handler/admin/admin_helpers_test.go +++ b/backend/internal/handler/admin/admin_helpers_test.go @@ -58,6 +58,96 @@ func TestParseOpsDuration(t *testing.T) { require.False(t, ok) } +func TestParseOpsOpenAITokenStatsDuration(t *testing.T) { + tests := []struct { + input string + want time.Duration + ok bool + }{ + {input: "30m", want: 30 * time.Minute, ok: true}, + {input: "1h", want: time.Hour, ok: true}, + {input: "1d", want: 24 * time.Hour, ok: true}, + {input: "15d", want: 15 * 24 * time.Hour, ok: true}, + {input: "30d", want: 30 * 24 * time.Hour, ok: true}, + {input: "7d", want: 0, ok: false}, + } + + for _, tt := range tests { + got, ok := parseOpsOpenAITokenStatsDuration(tt.input) + require.Equal(t, tt.ok, ok, "input=%s", tt.input) + require.Equal(t, tt.want, got, "input=%s", tt.input) + } +} + +func TestParseOpsOpenAITokenStatsFilter_Defaults(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + 
c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + before := time.Now().UTC() + filter, err := parseOpsOpenAITokenStatsFilter(c) + after := time.Now().UTC() + + require.NoError(t, err) + require.NotNil(t, filter) + require.Equal(t, "30d", filter.TimeRange) + require.Equal(t, 1, filter.Page) + require.Equal(t, 20, filter.PageSize) + require.Equal(t, 0, filter.TopN) + require.Nil(t, filter.GroupID) + require.Equal(t, "", filter.Platform) + require.True(t, filter.StartTime.Before(filter.EndTime)) + require.WithinDuration(t, before.Add(-30*24*time.Hour), filter.StartTime, 2*time.Second) + require.WithinDuration(t, after, filter.EndTime, 2*time.Second) +} + +func TestParseOpsOpenAITokenStatsFilter_WithTopN(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest( + http.MethodGet, + "/?time_range=1h&platform=openai&group_id=12&top_n=50", + nil, + ) + + filter, err := parseOpsOpenAITokenStatsFilter(c) + require.NoError(t, err) + require.Equal(t, "1h", filter.TimeRange) + require.Equal(t, "openai", filter.Platform) + require.NotNil(t, filter.GroupID) + require.Equal(t, int64(12), *filter.GroupID) + require.Equal(t, 50, filter.TopN) + require.Equal(t, 0, filter.Page) + require.Equal(t, 0, filter.PageSize) +} + +func TestParseOpsOpenAITokenStatsFilter_InvalidParams(t *testing.T) { + tests := []string{ + "/?time_range=7d", + "/?group_id=0", + "/?group_id=abc", + "/?top_n=0", + "/?top_n=101", + "/?top_n=10&page=1", + "/?top_n=10&page_size=20", + "/?page=0", + "/?page_size=0", + "/?page_size=101", + } + + gin.SetMode(gin.TestMode) + for _, rawURL := range tests { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, rawURL, nil) + + _, err := parseOpsOpenAITokenStatsFilter(c) + require.Error(t, err, "url=%s", rawURL) + } +} + func TestParseOpsTimeRange(t *testing.T) { gin.SetMode(gin.TestMode) w := httptest.NewRecorder() 
diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index d44c99ea..f3b99ddb 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -10,19 +10,28 @@ import ( ) type stubAdminService struct { - users []service.User - apiKeys []service.APIKey - groups []service.Group - accounts []service.Account - proxies []service.Proxy - proxyCounts []service.ProxyWithAccountCount - redeems []service.RedeemCode - createdAccounts []*service.CreateAccountInput - createdProxies []*service.CreateProxyInput - updatedProxyIDs []int64 - updatedProxies []*service.UpdateProxyInput - testedProxyIDs []int64 - mu sync.Mutex + users []service.User + apiKeys []service.APIKey + groups []service.Group + accounts []service.Account + proxies []service.Proxy + proxyCounts []service.ProxyWithAccountCount + redeems []service.RedeemCode + createdAccounts []*service.CreateAccountInput + createdProxies []*service.CreateProxyInput + updatedProxyIDs []int64 + updatedProxies []*service.UpdateProxyInput + testedProxyIDs []int64 + createAccountErr error + updateAccountErr error + bulkUpdateAccountErr error + checkMixedErr error + lastMixedCheck struct { + accountID int64 + platform string + groupIDs []int64 + } + mu sync.Mutex } func newStubAdminService() *stubAdminService { @@ -188,11 +197,17 @@ func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.Cre s.mu.Lock() s.createdAccounts = append(s.createdAccounts, input) s.mu.Unlock() + if s.createAccountErr != nil { + return nil, s.createAccountErr + } account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive} return &account, nil } func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) { + if s.updateAccountErr != nil { + return nil, s.updateAccountErr + } account := 
service.Account{ID: id, Name: input.Name, Status: service.StatusActive} return &account, nil } @@ -221,7 +236,17 @@ func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64, } func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *service.BulkUpdateAccountsInput) (*service.BulkUpdateAccountsResult, error) { - return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil + if s.bulkUpdateAccountErr != nil { + return nil, s.bulkUpdateAccountErr + } + return &service.BulkUpdateAccountsResult{Success: len(input.AccountIDs), Failed: 0, SuccessIDs: input.AccountIDs}, nil +} + +func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error { + s.lastMixedCheck.accountID = currentAccountID + s.lastMixedCheck.platform = currentAccountPlatform + s.lastMixedCheck.groupIDs = append([]int64(nil), groupIDs...) + return s.checkMixedErr } func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) { @@ -327,6 +352,27 @@ func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.Pr return &service.ProxyTestResult{Success: true, Message: "ok"}, nil } +func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*service.ProxyQualityCheckResult, error) { + return &service.ProxyQualityCheckResult{ + ProxyID: id, + Score: 95, + Grade: "A", + Summary: "通过 5 项,告警 0 项,失败 0 项,挑战 0 项", + PassedCount: 5, + WarnCount: 0, + FailedCount: 0, + ChallengeCount: 0, + CheckedAt: time.Now().Unix(), + Items: []service.ProxyQualityCheckItem{ + {Target: "base_connectivity", Status: "pass", Message: "ok"}, + {Target: "openai", Status: "pass", HTTPStatus: 401}, + {Target: "anthropic", Status: "pass", HTTPStatus: 401}, + {Target: "gemini", Status: "pass", HTTPStatus: 200}, + {Target: "sora", Status: "pass", HTTPStatus: 401}, + }, + 
}, nil +} + func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) { return s.redeems, int64(len(s.redeems)), nil } @@ -361,5 +407,23 @@ func (s *stubAdminService) UpdateGroupSortOrders(ctx context.Context, updates [] return nil } +func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*service.AdminUpdateAPIKeyGroupIDResult, error) { + for i := range s.apiKeys { + if s.apiKeys[i].ID == keyID { + k := s.apiKeys[i] + if groupID != nil { + if *groupID == 0 { + k.GroupID = nil + } else { + gid := *groupID + k.GroupID = &gid + } + } + return &service.AdminUpdateAPIKeyGroupIDResult{APIKey: &k}, nil + } + } + return nil, service.ErrAPIKeyNotFound +} + // Ensure stub implements interface. var _ service.AdminService = (*stubAdminService)(nil) diff --git a/backend/internal/handler/admin/apikey_handler.go b/backend/internal/handler/admin/apikey_handler.go new file mode 100644 index 00000000..8dd245a4 --- /dev/null +++ b/backend/internal/handler/admin/apikey_handler.go @@ -0,0 +1,63 @@ +package admin + +import ( + "strconv" + + "github.com/Wei-Shaw/sub2api/internal/handler/dto" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +// AdminAPIKeyHandler handles admin API key management +type AdminAPIKeyHandler struct { + adminService service.AdminService +} + +// NewAdminAPIKeyHandler creates a new admin API key handler +func NewAdminAPIKeyHandler(adminService service.AdminService) *AdminAPIKeyHandler { + return &AdminAPIKeyHandler{ + adminService: adminService, + } +} + +// AdminUpdateAPIKeyGroupRequest represents the request to update an API key's group +type AdminUpdateAPIKeyGroupRequest struct { + GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组 +} + +// UpdateGroup handles updating an API key's group binding +// PUT 
/api/v1/admin/api-keys/:id +func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) { + keyID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid API key ID") + return + } + + var req AdminUpdateAPIKeyGroupRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + result, err := h.adminService.AdminUpdateAPIKeyGroupID(c.Request.Context(), keyID, req.GroupID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + resp := struct { + APIKey *dto.APIKey `json:"api_key"` + AutoGrantedGroupAccess bool `json:"auto_granted_group_access"` + GrantedGroupID *int64 `json:"granted_group_id,omitempty"` + GrantedGroupName string `json:"granted_group_name,omitempty"` + }{ + APIKey: dto.APIKeyFromService(result.APIKey), + AutoGrantedGroupAccess: result.AutoGrantedGroupAccess, + GrantedGroupID: result.GrantedGroupID, + GrantedGroupName: result.GrantedGroupName, + } + response.Success(c, resp) +} diff --git a/backend/internal/handler/admin/apikey_handler_test.go b/backend/internal/handler/admin/apikey_handler_test.go new file mode 100644 index 00000000..bf128b18 --- /dev/null +++ b/backend/internal/handler/admin/apikey_handler_test.go @@ -0,0 +1,202 @@ +package admin + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func setupAPIKeyHandler(adminSvc service.AdminService) *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + h := NewAdminAPIKeyHandler(adminSvc) + router.PUT("/api/v1/admin/api-keys/:id", h.UpdateGroup) + return router +} + +func TestAdminAPIKeyHandler_UpdateGroup_InvalidID(t *testing.T) { + router := setupAPIKeyHandler(newStubAdminService()) + body := `{"group_id": 2}` + + rec := 
httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/abc", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, rec.Body.String(), "Invalid API key ID") +} + +func TestAdminAPIKeyHandler_UpdateGroup_InvalidJSON(t *testing.T) { + router := setupAPIKeyHandler(newStubAdminService()) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{bad json`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, rec.Body.String(), "Invalid request") +} + +func TestAdminAPIKeyHandler_UpdateGroup_KeyNotFound(t *testing.T) { + router := setupAPIKeyHandler(newStubAdminService()) + body := `{"group_id": 2}` + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/999", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + // ErrAPIKeyNotFound maps to 404 + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestAdminAPIKeyHandler_UpdateGroup_BindGroup(t *testing.T) { + router := setupAPIKeyHandler(newStubAdminService()) + body := `{"group_id": 2}` + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp struct { + Code int `json:"code"` + Data json.RawMessage `json:"data"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + + var data struct { + APIKey struct { + ID int64 `json:"id"` + GroupID *int64 `json:"group_id"` + } `json:"api_key"` + 
AutoGrantedGroupAccess bool `json:"auto_granted_group_access"` + } + require.NoError(t, json.Unmarshal(resp.Data, &data)) + require.Equal(t, int64(10), data.APIKey.ID) + require.NotNil(t, data.APIKey.GroupID) + require.Equal(t, int64(2), *data.APIKey.GroupID) +} + +func TestAdminAPIKeyHandler_UpdateGroup_Unbind(t *testing.T) { + svc := newStubAdminService() + gid := int64(2) + svc.apiKeys[0].GroupID = &gid + router := setupAPIKeyHandler(svc) + body := `{"group_id": 0}` + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var resp struct { + Data struct { + APIKey struct { + GroupID *int64 `json:"group_id"` + } `json:"api_key"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Nil(t, resp.Data.APIKey.GroupID) +} + +func TestAdminAPIKeyHandler_UpdateGroup_ServiceError(t *testing.T) { + svc := &failingUpdateGroupService{ + stubAdminService: newStubAdminService(), + err: errors.New("internal failure"), + } + router := setupAPIKeyHandler(svc) + body := `{"group_id": 2}` + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +// H2: empty body → group_id is nil → no-op, returns original key +func TestAdminAPIKeyHandler_UpdateGroup_EmptyBody_NoChange(t *testing.T) { + router := setupAPIKeyHandler(newStubAdminService()) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var 
resp struct { + Code int `json:"code"` + Data struct { + APIKey struct { + ID int64 `json:"id"` + } `json:"api_key"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, int64(10), resp.Data.APIKey.ID) +} + +// M2: service returns GROUP_NOT_ACTIVE → handler maps to 400 +func TestAdminAPIKeyHandler_UpdateGroup_GroupNotActive(t *testing.T) { + svc := &failingUpdateGroupService{ + stubAdminService: newStubAdminService(), + err: infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active"), + } + router := setupAPIKeyHandler(svc) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"group_id": 5}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, rec.Body.String(), "GROUP_NOT_ACTIVE") +} + +// M2: service returns INVALID_GROUP_ID → handler maps to 400 +func TestAdminAPIKeyHandler_UpdateGroup_NegativeGroupID(t *testing.T) { + svc := &failingUpdateGroupService{ + stubAdminService: newStubAdminService(), + err: infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative"), + } + router := setupAPIKeyHandler(svc) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"group_id": -5}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, rec.Body.String(), "INVALID_GROUP_ID") +} + +// failingUpdateGroupService overrides AdminUpdateAPIKeyGroupID to return an error. 
+type failingUpdateGroupService struct { + *stubAdminService + err error +} + +func (f *failingUpdateGroupService) AdminUpdateAPIKeyGroupID(_ context.Context, _ int64, _ *int64) (*service.AdminUpdateAPIKeyGroupIDResult, error) { + return nil, f.err +} diff --git a/backend/internal/handler/admin/batch_update_credentials_test.go b/backend/internal/handler/admin/batch_update_credentials_test.go new file mode 100644 index 00000000..0b1b6691 --- /dev/null +++ b/backend/internal/handler/admin/batch_update_credentials_test.go @@ -0,0 +1,208 @@ +//go:build unit + +package admin + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// failingAdminService 嵌入 stubAdminService,可配置 UpdateAccount 在指定 ID 时失败。 +type failingAdminService struct { + *stubAdminService + failOnAccountID int64 + updateCallCount atomic.Int64 +} + +func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) { + f.updateCallCount.Add(1) + if id == f.failOnAccountID { + return nil, errors.New("database error") + } + return f.stubAdminService.UpdateAccount(ctx, id, input) +} + +func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) { + gin.SetMode(gin.TestMode) + router := gin.New() + handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials) + return router, handler +} + +func TestBatchUpdateCredentials_AllSuccess(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + body, _ := json.Marshal(BatchUpdateCredentialsRequest{ + AccountIDs: []int64{1, 2, 3}, + Field: 
"account_uuid", + Value: "test-uuid", + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, "全部成功时应返回 200") + require.Equal(t, int64(3), svc.updateCallCount.Load(), "应调用 3 次 UpdateAccount") +} + +func TestBatchUpdateCredentials_PartialFailure(t *testing.T) { + // 让第 2 个账号(ID=2)更新时失败 + svc := &failingAdminService{ + stubAdminService: newStubAdminService(), + failOnAccountID: 2, + } + router, _ := setupAccountHandlerWithService(svc) + + body, _ := json.Marshal(BatchUpdateCredentialsRequest{ + AccountIDs: []int64{1, 2, 3}, + Field: "org_uuid", + Value: "test-org", + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + // 实现采用"部分成功"模式:总是返回 200 + 成功/失败明细 + require.Equal(t, http.StatusOK, w.Code, "批量更新返回 200 + 成功/失败明细") + + var resp map[string]any + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) + data := resp["data"].(map[string]any) + require.Equal(t, float64(2), data["success"], "应有 2 个成功") + require.Equal(t, float64(1), data["failed"], "应有 1 个失败") + + // 所有 3 个账号都会被尝试更新(非 fail-fast) + require.Equal(t, int64(3), svc.updateCallCount.Load(), + "应调用 3 次 UpdateAccount(逐个尝试,失败后继续)") +} + +func TestBatchUpdateCredentials_FirstAccountNotFound(t *testing.T) { + // GetAccount 在 stubAdminService 中总是成功的,需要创建一个 GetAccount 会失败的 stub + svc := &getAccountFailingService{ + stubAdminService: newStubAdminService(), + failOnAccountID: 1, + } + router, _ := setupAccountHandlerWithService(svc) + + body, _ := json.Marshal(BatchUpdateCredentialsRequest{ + AccountIDs: []int64{1, 2, 3}, + Field: "account_uuid", + Value: "test", + }) + + w := httptest.NewRecorder() + req, _ := 
http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusNotFound, w.Code, "第一阶段验证失败应返回 404") +} + +// getAccountFailingService 模拟 GetAccount 在特定 ID 时返回 not found。 +type getAccountFailingService struct { + *stubAdminService + failOnAccountID int64 +} + +func (f *getAccountFailingService) GetAccount(ctx context.Context, id int64) (*service.Account, error) { + if id == f.failOnAccountID { + return nil, errors.New("not found") + } + return f.stubAdminService.GetAccount(ctx, id) +} + +func TestBatchUpdateCredentials_InterceptWarmupRequests_NonBool(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + // intercept_warmup_requests 传入非 bool 类型(string),应返回 400 + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1}, + "field": "intercept_warmup_requests", + "value": "not-a-bool", + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusBadRequest, w.Code, + "intercept_warmup_requests 传入非 bool 值应返回 400") +} + +func TestBatchUpdateCredentials_InterceptWarmupRequests_ValidBool(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1}, + "field": "intercept_warmup_requests", + "value": true, + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, + 
"intercept_warmup_requests 传入合法 bool 值应返回 200") +} + +func TestBatchUpdateCredentials_AccountUUID_NonString(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + // account_uuid 传入非 string 类型(number),应返回 400 + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1}, + "field": "account_uuid", + "value": 12345, + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusBadRequest, w.Code, + "account_uuid 传入非 string 值应返回 400") +} + +func TestBatchUpdateCredentials_AccountUUID_NullValue(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + // account_uuid 传入 null(设置为空),应正常通过 + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1}, + "field": "account_uuid", + "value": nil, + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, + "account_uuid 传入 null 应返回 200") +} diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index 18365186..b0da6c5e 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -1,8 +1,10 @@ package admin import ( + "encoding/json" "errors" "strconv" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -186,7 +188,7 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) { // GetUsageTrend handles getting usage trend data // GET /api/v1/admin/dashboard/trend -// Query params: 
start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type +// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, request_type, stream, billing_type func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { startTime, endTime := parseTimeRange(c) granularity := c.DefaultQuery("granularity", "day") @@ -194,6 +196,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { // Parse optional filter params var userID, apiKeyID, accountID, groupID int64 var model string + var requestType *int16 var stream *bool var billingType *int8 @@ -220,9 +223,20 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { if modelStr := c.Query("model"); modelStr != "" { model = modelStr } - if streamStr := c.Query("stream"); streamStr != "" { + if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { + parsed, err := service.ParseUsageRequestType(requestTypeStr) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + } else if streamStr := c.Query("stream"); streamStr != "" { if streamVal, err := strconv.ParseBool(streamStr); err == nil { stream = &streamVal + } else { + response.BadRequest(c, "Invalid stream value, use true or false") + return } } if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" { @@ -235,7 +249,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { } } - trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType) + trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) if err != nil { response.Error(c, 500, "Failed to get usage trend") return @@ 
-251,12 +265,13 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { // GetModelStats handles getting model usage statistics // GET /api/v1/admin/dashboard/models -// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type +// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type func (h *DashboardHandler) GetModelStats(c *gin.Context) { startTime, endTime := parseTimeRange(c) // Parse optional filter params var userID, apiKeyID, accountID, groupID int64 + var requestType *int16 var stream *bool var billingType *int8 @@ -280,9 +295,20 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) { groupID = id } } - if streamStr := c.Query("stream"); streamStr != "" { + if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { + parsed, err := service.ParseUsageRequestType(requestTypeStr) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + } else if streamStr := c.Query("stream"); streamStr != "" { if streamVal, err := strconv.ParseBool(streamStr); err == nil { stream = &streamVal + } else { + response.BadRequest(c, "Invalid stream value, use true or false") + return } } if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" { @@ -295,7 +321,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) { } } - stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType) + stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) if err != nil { response.Error(c, 500, "Failed to get model statistics") return @@ -308,6 +334,76 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) { }) } +// GetGroupStats handles 
getting group usage statistics +// GET /api/v1/admin/dashboard/groups +// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type +func (h *DashboardHandler) GetGroupStats(c *gin.Context) { + startTime, endTime := parseTimeRange(c) + + var userID, apiKeyID, accountID, groupID int64 + var requestType *int16 + var stream *bool + var billingType *int8 + + if userIDStr := c.Query("user_id"); userIDStr != "" { + if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil { + userID = id + } + } + if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" { + if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil { + apiKeyID = id + } + } + if accountIDStr := c.Query("account_id"); accountIDStr != "" { + if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil { + accountID = id + } + } + if groupIDStr := c.Query("group_id"); groupIDStr != "" { + if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil { + groupID = id + } + } + if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { + parsed, err := service.ParseUsageRequestType(requestTypeStr) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + } else if streamStr := c.Query("stream"); streamStr != "" { + if streamVal, err := strconv.ParseBool(streamStr); err == nil { + stream = &streamVal + } else { + response.BadRequest(c, "Invalid stream value, use true or false") + return + } + } + if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" { + if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil { + bt := int8(v) + billingType = &bt + } else { + response.BadRequest(c, "Invalid billing_type") + return + } + } + + stats, err := h.dashboardService.GetGroupStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + if err != nil { + 
response.Error(c, 500, "Failed to get group statistics") + return + } + + response.Success(c, gin.H{ + "groups": stats, + "start_date": startTime.Format("2006-01-02"), + "end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"), + }) +} + // GetAPIKeyUsageTrend handles getting API key usage trend data // GET /api/v1/admin/dashboard/api-keys-trend // Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 5) @@ -365,6 +461,9 @@ type BatchUsersUsageRequest struct { UserIDs []int64 `json:"user_ids" binding:"required"` } +var dashboardBatchUsersUsageCache = newSnapshotCache(30 * time.Second) +var dashboardBatchAPIKeysUsageCache = newSnapshotCache(30 * time.Second) + // GetBatchUsersUsage handles getting usage stats for multiple users // POST /api/v1/admin/dashboard/users-usage func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) { @@ -374,18 +473,34 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) { return } - if len(req.UserIDs) == 0 { + userIDs := normalizeInt64IDList(req.UserIDs) + if len(userIDs) == 0 { response.Success(c, gin.H{"stats": map[string]any{}}) return } - stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs) + keyRaw, _ := json.Marshal(struct { + UserIDs []int64 `json:"user_ids"` + }{ + UserIDs: userIDs, + }) + cacheKey := string(keyRaw) + if cached, ok := dashboardBatchUsersUsageCache.Get(cacheKey); ok { + c.Header("X-Snapshot-Cache", "hit") + response.Success(c, cached.Payload) + return + } + + stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), userIDs, time.Time{}, time.Time{}) if err != nil { response.Error(c, 500, "Failed to get user usage stats") return } - response.Success(c, gin.H{"stats": stats}) + payload := gin.H{"stats": stats} + dashboardBatchUsersUsageCache.Set(cacheKey, payload) + c.Header("X-Snapshot-Cache", "miss") + response.Success(c, payload) } // BatchAPIKeysUsageRequest represents the request body for batch api 
key usage stats @@ -402,16 +517,32 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) { return } - if len(req.APIKeyIDs) == 0 { + apiKeyIDs := normalizeInt64IDList(req.APIKeyIDs) + if len(apiKeyIDs) == 0 { response.Success(c, gin.H{"stats": map[string]any{}}) return } - stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs) + keyRaw, _ := json.Marshal(struct { + APIKeyIDs []int64 `json:"api_key_ids"` + }{ + APIKeyIDs: apiKeyIDs, + }) + cacheKey := string(keyRaw) + if cached, ok := dashboardBatchAPIKeysUsageCache.Get(cacheKey); ok { + c.Header("X-Snapshot-Cache", "hit") + response.Success(c, cached.Payload) + return + } + + stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), apiKeyIDs, time.Time{}, time.Time{}) if err != nil { response.Error(c, 500, "Failed to get API key usage stats") return } - response.Success(c, gin.H{"stats": stats}) + payload := gin.H{"stats": stats} + dashboardBatchAPIKeysUsageCache.Set(cacheKey, payload) + c.Header("X-Snapshot-Cache", "miss") + response.Success(c, payload) } diff --git a/backend/internal/handler/admin/dashboard_handler_request_type_test.go b/backend/internal/handler/admin/dashboard_handler_request_type_test.go new file mode 100644 index 00000000..72af6b45 --- /dev/null +++ b/backend/internal/handler/admin/dashboard_handler_request_type_test.go @@ -0,0 +1,132 @@ +package admin + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type dashboardUsageRepoCapture struct { + service.UsageLogRepository + trendRequestType *int16 + trendStream *bool + modelRequestType *int16 + modelStream *bool +} + +func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters( + ctx context.Context, + startTime, endTime time.Time, + granularity string, + userID, 
apiKeyID, accountID, groupID int64, + model string, + requestType *int16, + stream *bool, + billingType *int8, +) ([]usagestats.TrendDataPoint, error) { + s.trendRequestType = requestType + s.trendStream = stream + return []usagestats.TrendDataPoint{}, nil +} + +func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters( + ctx context.Context, + startTime, endTime time.Time, + userID, apiKeyID, accountID, groupID int64, + requestType *int16, + stream *bool, + billingType *int8, +) ([]usagestats.ModelStat, error) { + s.modelRequestType = requestType + s.modelStream = stream + return []usagestats.ModelStat{}, nil +} + +func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine { + gin.SetMode(gin.TestMode) + dashboardSvc := service.NewDashboardService(repo, nil, nil, nil) + handler := NewDashboardHandler(dashboardSvc, nil) + router := gin.New() + router.GET("/admin/dashboard/trend", handler.GetUsageTrend) + router.GET("/admin/dashboard/models", handler.GetModelStats) + return router +} + +func TestDashboardTrendRequestTypePriority(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=ws_v2&stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, repo.trendRequestType) + require.Equal(t, int16(service.RequestTypeWSV2), *repo.trendRequestType) + require.Nil(t, repo.trendStream) +} + +func TestDashboardTrendInvalidRequestType(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestDashboardTrendInvalidStream(t *testing.T) { + repo := 
&dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestDashboardModelStatsRequestTypePriority(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=sync&stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, repo.modelRequestType) + require.Equal(t, int16(service.RequestTypeSync), *repo.modelRequestType) + require.Nil(t, repo.modelStream) +} + +func TestDashboardModelStatsInvalidRequestType(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestDashboardModelStatsInvalidStream(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} diff --git a/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go b/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go new file mode 100644 index 00000000..f6db69f3 --- /dev/null +++ b/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go @@ -0,0 +1,292 @@ +package admin + +import ( + "encoding/json" + "net/http" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + 
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +var dashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second) + +type dashboardSnapshotV2Stats struct { + usagestats.DashboardStats + Uptime int64 `json:"uptime"` +} + +type dashboardSnapshotV2Response struct { + GeneratedAt string `json:"generated_at"` + + StartDate string `json:"start_date"` + EndDate string `json:"end_date"` + Granularity string `json:"granularity"` + + Stats *dashboardSnapshotV2Stats `json:"stats,omitempty"` + Trend []usagestats.TrendDataPoint `json:"trend,omitempty"` + Models []usagestats.ModelStat `json:"models,omitempty"` + Groups []usagestats.GroupStat `json:"groups,omitempty"` + UsersTrend []usagestats.UserUsageTrendPoint `json:"users_trend,omitempty"` +} + +type dashboardSnapshotV2Filters struct { + UserID int64 + APIKeyID int64 + AccountID int64 + GroupID int64 + Model string + RequestType *int16 + Stream *bool + BillingType *int8 +} + +type dashboardSnapshotV2CacheKey struct { + StartTime string `json:"start_time"` + EndTime string `json:"end_time"` + Granularity string `json:"granularity"` + UserID int64 `json:"user_id"` + APIKeyID int64 `json:"api_key_id"` + AccountID int64 `json:"account_id"` + GroupID int64 `json:"group_id"` + Model string `json:"model"` + RequestType *int16 `json:"request_type"` + Stream *bool `json:"stream"` + BillingType *int8 `json:"billing_type"` + IncludeStats bool `json:"include_stats"` + IncludeTrend bool `json:"include_trend"` + IncludeModels bool `json:"include_models"` + IncludeGroups bool `json:"include_groups"` + IncludeUsersTrend bool `json:"include_users_trend"` + UsersTrendLimit int `json:"users_trend_limit"` +} + +func (h *DashboardHandler) GetSnapshotV2(c *gin.Context) { + startTime, endTime := parseTimeRange(c) + granularity := strings.TrimSpace(c.DefaultQuery("granularity", "day")) + if granularity != "hour" { + granularity = "day" + } + + includeStats := 
parseBoolQueryWithDefault(c.Query("include_stats"), true) + includeTrend := parseBoolQueryWithDefault(c.Query("include_trend"), true) + includeModels := parseBoolQueryWithDefault(c.Query("include_model_stats"), true) + includeGroups := parseBoolQueryWithDefault(c.Query("include_group_stats"), false) + includeUsersTrend := parseBoolQueryWithDefault(c.Query("include_users_trend"), false) + usersTrendLimit := 12 + if raw := strings.TrimSpace(c.Query("users_trend_limit")); raw != "" { + if parsed, err := strconv.Atoi(raw); err == nil && parsed > 0 && parsed <= 50 { + usersTrendLimit = parsed + } + } + + filters, err := parseDashboardSnapshotV2Filters(c) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + keyRaw, _ := json.Marshal(dashboardSnapshotV2CacheKey{ + StartTime: startTime.UTC().Format(time.RFC3339), + EndTime: endTime.UTC().Format(time.RFC3339), + Granularity: granularity, + UserID: filters.UserID, + APIKeyID: filters.APIKeyID, + AccountID: filters.AccountID, + GroupID: filters.GroupID, + Model: filters.Model, + RequestType: filters.RequestType, + Stream: filters.Stream, + BillingType: filters.BillingType, + IncludeStats: includeStats, + IncludeTrend: includeTrend, + IncludeModels: includeModels, + IncludeGroups: includeGroups, + IncludeUsersTrend: includeUsersTrend, + UsersTrendLimit: usersTrendLimit, + }) + cacheKey := string(keyRaw) + + if cached, ok := dashboardSnapshotV2Cache.Get(cacheKey); ok { + if cached.ETag != "" { + c.Header("ETag", cached.ETag) + c.Header("Vary", "If-None-Match") + if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) { + c.Status(http.StatusNotModified) + return + } + } + c.Header("X-Snapshot-Cache", "hit") + response.Success(c, cached.Payload) + return + } + + resp := &dashboardSnapshotV2Response{ + GeneratedAt: time.Now().UTC().Format(time.RFC3339), + StartDate: startTime.Format("2006-01-02"), + EndDate: endTime.Add(-24 * time.Hour).Format("2006-01-02"), + Granularity: granularity, + } + + if 
includeStats { + stats, err := h.dashboardService.GetDashboardStats(c.Request.Context()) + if err != nil { + response.Error(c, 500, "Failed to get dashboard statistics") + return + } + resp.Stats = &dashboardSnapshotV2Stats{ + DashboardStats: *stats, + Uptime: int64(time.Since(h.startTime).Seconds()), + } + } + + if includeTrend { + trend, err := h.dashboardService.GetUsageTrendWithFilters( + c.Request.Context(), + startTime, + endTime, + granularity, + filters.UserID, + filters.APIKeyID, + filters.AccountID, + filters.GroupID, + filters.Model, + filters.RequestType, + filters.Stream, + filters.BillingType, + ) + if err != nil { + response.Error(c, 500, "Failed to get usage trend") + return + } + resp.Trend = trend + } + + if includeModels { + models, err := h.dashboardService.GetModelStatsWithFilters( + c.Request.Context(), + startTime, + endTime, + filters.UserID, + filters.APIKeyID, + filters.AccountID, + filters.GroupID, + filters.RequestType, + filters.Stream, + filters.BillingType, + ) + if err != nil { + response.Error(c, 500, "Failed to get model statistics") + return + } + resp.Models = models + } + + if includeGroups { + groups, err := h.dashboardService.GetGroupStatsWithFilters( + c.Request.Context(), + startTime, + endTime, + filters.UserID, + filters.APIKeyID, + filters.AccountID, + filters.GroupID, + filters.RequestType, + filters.Stream, + filters.BillingType, + ) + if err != nil { + response.Error(c, 500, "Failed to get group statistics") + return + } + resp.Groups = groups + } + + if includeUsersTrend { + usersTrend, err := h.dashboardService.GetUserUsageTrend( + c.Request.Context(), + startTime, + endTime, + granularity, + usersTrendLimit, + ) + if err != nil { + response.Error(c, 500, "Failed to get user usage trend") + return + } + resp.UsersTrend = usersTrend + } + + cached := dashboardSnapshotV2Cache.Set(cacheKey, resp) + if cached.ETag != "" { + c.Header("ETag", cached.ETag) + c.Header("Vary", "If-None-Match") + } + 
c.Header("X-Snapshot-Cache", "miss") + response.Success(c, resp) +} + +func parseDashboardSnapshotV2Filters(c *gin.Context) (*dashboardSnapshotV2Filters, error) { + filters := &dashboardSnapshotV2Filters{ + Model: strings.TrimSpace(c.Query("model")), + } + + if userIDStr := strings.TrimSpace(c.Query("user_id")); userIDStr != "" { + id, err := strconv.ParseInt(userIDStr, 10, 64) + if err != nil { + return nil, err + } + filters.UserID = id + } + if apiKeyIDStr := strings.TrimSpace(c.Query("api_key_id")); apiKeyIDStr != "" { + id, err := strconv.ParseInt(apiKeyIDStr, 10, 64) + if err != nil { + return nil, err + } + filters.APIKeyID = id + } + if accountIDStr := strings.TrimSpace(c.Query("account_id")); accountIDStr != "" { + id, err := strconv.ParseInt(accountIDStr, 10, 64) + if err != nil { + return nil, err + } + filters.AccountID = id + } + if groupIDStr := strings.TrimSpace(c.Query("group_id")); groupIDStr != "" { + id, err := strconv.ParseInt(groupIDStr, 10, 64) + if err != nil { + return nil, err + } + filters.GroupID = id + } + + if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { + parsed, err := service.ParseUsageRequestType(requestTypeStr) + if err != nil { + return nil, err + } + value := int16(parsed) + filters.RequestType = &value + } else if streamStr := strings.TrimSpace(c.Query("stream")); streamStr != "" { + streamVal, err := strconv.ParseBool(streamStr) + if err != nil { + return nil, err + } + filters.Stream = &streamVal + } + + if billingTypeStr := strings.TrimSpace(c.Query("billing_type")); billingTypeStr != "" { + v, err := strconv.ParseInt(billingTypeStr, 10, 8) + if err != nil { + return nil, err + } + bt := int8(v) + filters.BillingType = &bt + } + + return filters, nil +} diff --git a/backend/internal/handler/admin/data_management_handler.go b/backend/internal/handler/admin/data_management_handler.go new file mode 100644 index 00000000..02fc766f --- /dev/null +++ 
b/backend/internal/handler/admin/data_management_handler.go @@ -0,0 +1,545 @@ +package admin + +import ( + "context" + "strconv" + "strings" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +type DataManagementHandler struct { + dataManagementService dataManagementService +} + +func NewDataManagementHandler(dataManagementService *service.DataManagementService) *DataManagementHandler { + return &DataManagementHandler{dataManagementService: dataManagementService} +} + +type dataManagementService interface { + GetConfig(ctx context.Context) (service.DataManagementConfig, error) + UpdateConfig(ctx context.Context, cfg service.DataManagementConfig) (service.DataManagementConfig, error) + ValidateS3(ctx context.Context, cfg service.DataManagementS3Config) (service.DataManagementTestS3Result, error) + CreateBackupJob(ctx context.Context, input service.DataManagementCreateBackupJobInput) (service.DataManagementBackupJob, error) + ListSourceProfiles(ctx context.Context, sourceType string) ([]service.DataManagementSourceProfile, error) + CreateSourceProfile(ctx context.Context, input service.DataManagementCreateSourceProfileInput) (service.DataManagementSourceProfile, error) + UpdateSourceProfile(ctx context.Context, input service.DataManagementUpdateSourceProfileInput) (service.DataManagementSourceProfile, error) + DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error + SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (service.DataManagementSourceProfile, error) + ListS3Profiles(ctx context.Context) ([]service.DataManagementS3Profile, error) + CreateS3Profile(ctx context.Context, input service.DataManagementCreateS3ProfileInput) (service.DataManagementS3Profile, error) + UpdateS3Profile(ctx 
context.Context, input service.DataManagementUpdateS3ProfileInput) (service.DataManagementS3Profile, error) + DeleteS3Profile(ctx context.Context, profileID string) error + SetActiveS3Profile(ctx context.Context, profileID string) (service.DataManagementS3Profile, error) + ListBackupJobs(ctx context.Context, input service.DataManagementListBackupJobsInput) (service.DataManagementListBackupJobsResult, error) + GetBackupJob(ctx context.Context, jobID string) (service.DataManagementBackupJob, error) + EnsureAgentEnabled(ctx context.Context) error + GetAgentHealth(ctx context.Context) service.DataManagementAgentHealth +} + +type TestS3ConnectionRequest struct { + Endpoint string `json:"endpoint"` + Region string `json:"region" binding:"required"` + Bucket string `json:"bucket" binding:"required"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + UseSSL bool `json:"use_ssl"` +} + +type CreateBackupJobRequest struct { + BackupType string `json:"backup_type" binding:"required,oneof=postgres redis full"` + UploadToS3 bool `json:"upload_to_s3"` + S3ProfileID string `json:"s3_profile_id"` + PostgresID string `json:"postgres_profile_id"` + RedisID string `json:"redis_profile_id"` + IdempotencyKey string `json:"idempotency_key"` +} + +type CreateSourceProfileRequest struct { + ProfileID string `json:"profile_id" binding:"required"` + Name string `json:"name" binding:"required"` + Config service.DataManagementSourceConfig `json:"config" binding:"required"` + SetActive bool `json:"set_active"` +} + +type UpdateSourceProfileRequest struct { + Name string `json:"name" binding:"required"` + Config service.DataManagementSourceConfig `json:"config" binding:"required"` +} + +type CreateS3ProfileRequest struct { + ProfileID string `json:"profile_id" binding:"required"` + Name string `json:"name" binding:"required"` + Enabled bool `json:"enabled"` + Endpoint 
string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + UseSSL bool `json:"use_ssl"` + SetActive bool `json:"set_active"` +} + +type UpdateS3ProfileRequest struct { + Name string `json:"name" binding:"required"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + UseSSL bool `json:"use_ssl"` +} + +func (h *DataManagementHandler) GetAgentHealth(c *gin.Context) { + health := h.getAgentHealth(c) + payload := gin.H{ + "enabled": health.Enabled, + "reason": health.Reason, + "socket_path": health.SocketPath, + } + if health.Agent != nil { + payload["agent"] = gin.H{ + "status": health.Agent.Status, + "version": health.Agent.Version, + "uptime_seconds": health.Agent.UptimeSeconds, + } + } + response.Success(c, payload) +} + +func (h *DataManagementHandler) GetConfig(c *gin.Context) { + if !h.requireAgentEnabled(c) { + return + } + cfg, err := h.dataManagementService.GetConfig(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, cfg) +} + +func (h *DataManagementHandler) UpdateConfig(c *gin.Context) { + var req service.DataManagementConfig + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if !h.requireAgentEnabled(c) { + return + } + cfg, err := h.dataManagementService.UpdateConfig(c.Request.Context(), req) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, cfg) +} + +func (h *DataManagementHandler) TestS3(c *gin.Context) { + var req TestS3ConnectionRequest + 
if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if !h.requireAgentEnabled(c) { + return + } + result, err := h.dataManagementService.ValidateS3(c.Request.Context(), service.DataManagementS3Config{ + Enabled: true, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + UseSSL: req.UseSSL, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"ok": result.OK, "message": result.Message}) +} + +func (h *DataManagementHandler) CreateBackupJob(c *gin.Context) { + var req CreateBackupJobRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + req.IdempotencyKey = normalizeBackupIdempotencyKey(c.GetHeader("X-Idempotency-Key"), req.IdempotencyKey) + if !h.requireAgentEnabled(c) { + return + } + + triggeredBy := "admin:unknown" + if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok { + triggeredBy = "admin:" + strconv.FormatInt(subject.UserID, 10) + } + job, err := h.dataManagementService.CreateBackupJob(c.Request.Context(), service.DataManagementCreateBackupJobInput{ + BackupType: req.BackupType, + UploadToS3: req.UploadToS3, + S3ProfileID: req.S3ProfileID, + PostgresID: req.PostgresID, + RedisID: req.RedisID, + TriggeredBy: triggeredBy, + IdempotencyKey: req.IdempotencyKey, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"job_id": job.JobID, "status": job.Status}) +} + +func (h *DataManagementHandler) ListSourceProfiles(c *gin.Context) { + sourceType := strings.TrimSpace(c.Param("source_type")) + if sourceType == "" { + response.BadRequest(c, "Invalid source_type") + return + } + if sourceType != "postgres" && sourceType != "redis" { + response.BadRequest(c, "source_type must be postgres 
or redis") + return + } + + if !h.requireAgentEnabled(c) { + return + } + items, err := h.dataManagementService.ListSourceProfiles(c.Request.Context(), sourceType) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"items": items}) +} + +func (h *DataManagementHandler) CreateSourceProfile(c *gin.Context) { + sourceType := strings.TrimSpace(c.Param("source_type")) + if sourceType != "postgres" && sourceType != "redis" { + response.BadRequest(c, "source_type must be postgres or redis") + return + } + + var req CreateSourceProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if !h.requireAgentEnabled(c) { + return + } + profile, err := h.dataManagementService.CreateSourceProfile(c.Request.Context(), service.DataManagementCreateSourceProfileInput{ + SourceType: sourceType, + ProfileID: req.ProfileID, + Name: req.Name, + Config: req.Config, + SetActive: req.SetActive, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, profile) +} + +func (h *DataManagementHandler) UpdateSourceProfile(c *gin.Context) { + sourceType := strings.TrimSpace(c.Param("source_type")) + if sourceType != "postgres" && sourceType != "redis" { + response.BadRequest(c, "source_type must be postgres or redis") + return + } + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Invalid profile_id") + return + } + + var req UpdateSourceProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if !h.requireAgentEnabled(c) { + return + } + profile, err := h.dataManagementService.UpdateSourceProfile(c.Request.Context(), service.DataManagementUpdateSourceProfileInput{ + SourceType: sourceType, + ProfileID: profileID, + Name: req.Name, + Config: req.Config, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + 
response.Success(c, profile) +} + +func (h *DataManagementHandler) DeleteSourceProfile(c *gin.Context) { + sourceType := strings.TrimSpace(c.Param("source_type")) + if sourceType != "postgres" && sourceType != "redis" { + response.BadRequest(c, "source_type must be postgres or redis") + return + } + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Invalid profile_id") + return + } + + if !h.requireAgentEnabled(c) { + return + } + if err := h.dataManagementService.DeleteSourceProfile(c.Request.Context(), sourceType, profileID); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"deleted": true}) +} + +func (h *DataManagementHandler) SetActiveSourceProfile(c *gin.Context) { + sourceType := strings.TrimSpace(c.Param("source_type")) + if sourceType != "postgres" && sourceType != "redis" { + response.BadRequest(c, "source_type must be postgres or redis") + return + } + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Invalid profile_id") + return + } + + if !h.requireAgentEnabled(c) { + return + } + profile, err := h.dataManagementService.SetActiveSourceProfile(c.Request.Context(), sourceType, profileID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, profile) +} + +func (h *DataManagementHandler) ListS3Profiles(c *gin.Context) { + if !h.requireAgentEnabled(c) { + return + } + + items, err := h.dataManagementService.ListS3Profiles(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"items": items}) +} + +func (h *DataManagementHandler) CreateS3Profile(c *gin.Context) { + var req CreateS3ProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if !h.requireAgentEnabled(c) { + return + } + + profile, err := 
h.dataManagementService.CreateS3Profile(c.Request.Context(), service.DataManagementCreateS3ProfileInput{ + ProfileID: req.ProfileID, + Name: req.Name, + SetActive: req.SetActive, + S3: service.DataManagementS3Config{ + Enabled: req.Enabled, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + UseSSL: req.UseSSL, + }, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, profile) +} + +func (h *DataManagementHandler) UpdateS3Profile(c *gin.Context) { + var req UpdateS3ProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Invalid profile_id") + return + } + + if !h.requireAgentEnabled(c) { + return + } + + profile, err := h.dataManagementService.UpdateS3Profile(c.Request.Context(), service.DataManagementUpdateS3ProfileInput{ + ProfileID: profileID, + Name: req.Name, + S3: service.DataManagementS3Config{ + Enabled: req.Enabled, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + UseSSL: req.UseSSL, + }, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, profile) +} + +func (h *DataManagementHandler) DeleteS3Profile(c *gin.Context) { + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Invalid profile_id") + return + } + + if !h.requireAgentEnabled(c) { + return + } + if err := h.dataManagementService.DeleteS3Profile(c.Request.Context(), profileID); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"deleted": true}) +} + +func (h 
*DataManagementHandler) SetActiveS3Profile(c *gin.Context) { + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Invalid profile_id") + return + } + + if !h.requireAgentEnabled(c) { + return + } + profile, err := h.dataManagementService.SetActiveS3Profile(c.Request.Context(), profileID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, profile) +} + +func (h *DataManagementHandler) ListBackupJobs(c *gin.Context) { + if !h.requireAgentEnabled(c) { + return + } + + pageSize := int32(20) + if raw := strings.TrimSpace(c.Query("page_size")); raw != "" { + v, err := strconv.Atoi(raw) + if err != nil || v <= 0 { + response.BadRequest(c, "Invalid page_size") + return + } + pageSize = int32(v) + } + + result, err := h.dataManagementService.ListBackupJobs(c.Request.Context(), service.DataManagementListBackupJobsInput{ + PageSize: pageSize, + PageToken: c.Query("page_token"), + Status: c.Query("status"), + BackupType: c.Query("backup_type"), + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, result) +} + +func (h *DataManagementHandler) GetBackupJob(c *gin.Context) { + jobID := strings.TrimSpace(c.Param("job_id")) + if jobID == "" { + response.BadRequest(c, "Invalid backup job ID") + return + } + + if !h.requireAgentEnabled(c) { + return + } + job, err := h.dataManagementService.GetBackupJob(c.Request.Context(), jobID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, job) +} + +func (h *DataManagementHandler) requireAgentEnabled(c *gin.Context) bool { + if h.dataManagementService == nil { + err := infraerrors.ServiceUnavailable( + service.DataManagementAgentUnavailableReason, + "data management agent service is not configured", + ).WithMetadata(map[string]string{"socket_path": service.DefaultDataManagementAgentSocketPath}) + response.ErrorFrom(c, err) + return false + } + + if err := 
h.dataManagementService.EnsureAgentEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return false + } + + return true +} + +func (h *DataManagementHandler) getAgentHealth(c *gin.Context) service.DataManagementAgentHealth { + if h.dataManagementService == nil { + return service.DataManagementAgentHealth{ + Enabled: false, + Reason: service.DataManagementAgentUnavailableReason, + SocketPath: service.DefaultDataManagementAgentSocketPath, + } + } + return h.dataManagementService.GetAgentHealth(c.Request.Context()) +} + +func normalizeBackupIdempotencyKey(headerValue, bodyValue string) string { + headerKey := strings.TrimSpace(headerValue) + if headerKey != "" { + return headerKey + } + return strings.TrimSpace(bodyValue) +} diff --git a/backend/internal/handler/admin/data_management_handler_test.go b/backend/internal/handler/admin/data_management_handler_test.go new file mode 100644 index 00000000..ce8ee835 --- /dev/null +++ b/backend/internal/handler/admin/data_management_handler_test.go @@ -0,0 +1,78 @@ +package admin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "path/filepath" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type apiEnvelope struct { + Code int `json:"code"` + Message string `json:"message"` + Reason string `json:"reason"` + Data json.RawMessage `json:"data"` +} + +func TestDataManagementHandler_AgentHealthAlways200(t *testing.T) { + gin.SetMode(gin.TestMode) + + svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond) + h := NewDataManagementHandler(svc) + + r := gin.New() + r.GET("/api/v1/admin/data-management/agent/health", h.GetAgentHealth) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/agent/health", nil) + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var 
envelope apiEnvelope + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope)) + require.Equal(t, 0, envelope.Code) + + var data struct { + Enabled bool `json:"enabled"` + Reason string `json:"reason"` + SocketPath string `json:"socket_path"` + } + require.NoError(t, json.Unmarshal(envelope.Data, &data)) + require.False(t, data.Enabled) + require.Equal(t, service.DataManagementDeprecatedReason, data.Reason) + require.Equal(t, svc.SocketPath(), data.SocketPath) +} + +func TestDataManagementHandler_NonHealthRouteReturns503WhenDisabled(t *testing.T) { + gin.SetMode(gin.TestMode) + + svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond) + h := NewDataManagementHandler(svc) + + r := gin.New() + r.GET("/api/v1/admin/data-management/config", h.GetConfig) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/config", nil) + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusServiceUnavailable, rec.Code) + + var envelope apiEnvelope + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope)) + require.Equal(t, http.StatusServiceUnavailable, envelope.Code) + require.Equal(t, service.DataManagementDeprecatedReason, envelope.Reason) +} + +func TestNormalizeBackupIdempotencyKey(t *testing.T) { + require.Equal(t, "from-header", normalizeBackupIdempotencyKey("from-header", "from-body")) + require.Equal(t, "from-body", normalizeBackupIdempotencyKey(" ", " from-body ")) + require.Equal(t, "", normalizeBackupIdempotencyKey("", "")) +} diff --git a/backend/internal/handler/admin/gemini_oauth_handler.go b/backend/internal/handler/admin/gemini_oauth_handler.go index 50caaa26..8c398a1e 100644 --- a/backend/internal/handler/admin/gemini_oauth_handler.go +++ b/backend/internal/handler/admin/gemini_oauth_handler.go @@ -61,7 +61,11 @@ func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) { if err != nil { msg := err.Error() // Treat missing/invalid 
OAuth client configuration as a user/config error. - if strings.Contains(msg, "OAuth client not configured") || strings.Contains(msg, "requires your own OAuth Client") { + if strings.Contains(msg, "OAuth client not configured") || + strings.Contains(msg, "requires your own OAuth Client") || + strings.Contains(msg, "requires a custom OAuth Client") || + strings.Contains(msg, "GEMINI_CLI_OAUTH_CLIENT_SECRET_MISSING") || + strings.Contains(msg, "built-in Gemini CLI OAuth client_secret is not configured") { response.BadRequest(c, "Failed to generate auth URL: "+msg) return } diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 7daaf281..1edf4dcc 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -27,7 +27,7 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler { type CreateGroupRequest struct { Name string `json:"name" binding:"required"` Description string `json:"description"` - Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` + Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"` RateMultiplier float64 `json:"rate_multiplier"` IsExclusive bool `json:"is_exclusive"` SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"` @@ -38,6 +38,10 @@ type CreateGroupRequest struct { ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` + SoraImagePrice360 *float64 `json:"sora_image_price_360"` + SoraImagePrice540 *float64 `json:"sora_image_price_540"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` + SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` 
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` @@ -47,6 +51,8 @@ type CreateGroupRequest struct { MCPXMLInject *bool `json:"mcp_xml_inject"` // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string `json:"supported_model_scopes"` + // Sora 存储配额 + SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` // 从指定分组复制账号(创建后自动绑定) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -55,7 +61,7 @@ type CreateGroupRequest struct { type UpdateGroupRequest struct { Name string `json:"name"` Description string `json:"description"` - Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` + Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"` RateMultiplier *float64 `json:"rate_multiplier"` IsExclusive *bool `json:"is_exclusive"` Status string `json:"status" binding:"omitempty,oneof=active inactive"` @@ -67,6 +73,10 @@ type UpdateGroupRequest struct { ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` + SoraImagePrice360 *float64 `json:"sora_image_price_360"` + SoraImagePrice540 *float64 `json:"sora_image_price_540"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` + SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` ClaudeCodeOnly *bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` @@ -76,6 +86,8 @@ type UpdateGroupRequest struct { MCPXMLInject *bool `json:"mcp_xml_inject"` // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes *[]string `json:"supported_model_scopes"` + // Sora 存储配额 + SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"` // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -179,6 
+191,10 @@ func (h *GroupHandler) Create(c *gin.Context) { ImagePrice1K: req.ImagePrice1K, ImagePrice2K: req.ImagePrice2K, ImagePrice4K: req.ImagePrice4K, + SoraImagePrice360: req.SoraImagePrice360, + SoraImagePrice540: req.SoraImagePrice540, + SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD, ClaudeCodeOnly: req.ClaudeCodeOnly, FallbackGroupID: req.FallbackGroupID, FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest, @@ -186,6 +202,7 @@ func (h *GroupHandler) Create(c *gin.Context) { ModelRoutingEnabled: req.ModelRoutingEnabled, MCPXMLInject: req.MCPXMLInject, SupportedModelScopes: req.SupportedModelScopes, + SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { @@ -225,6 +242,10 @@ func (h *GroupHandler) Update(c *gin.Context) { ImagePrice1K: req.ImagePrice1K, ImagePrice2K: req.ImagePrice2K, ImagePrice4K: req.ImagePrice4K, + SoraImagePrice360: req.SoraImagePrice360, + SoraImagePrice540: req.SoraImagePrice540, + SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD, ClaudeCodeOnly: req.ClaudeCodeOnly, FallbackGroupID: req.FallbackGroupID, FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest, @@ -232,6 +253,7 @@ func (h *GroupHandler) Update(c *gin.Context) { ModelRoutingEnabled: req.ModelRoutingEnabled, MCPXMLInject: req.MCPXMLInject, SupportedModelScopes: req.SupportedModelScopes, + SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { diff --git a/backend/internal/handler/admin/id_list_utils.go b/backend/internal/handler/admin/id_list_utils.go new file mode 100644 index 00000000..2aeefe38 --- /dev/null +++ b/backend/internal/handler/admin/id_list_utils.go @@ -0,0 +1,25 @@ +package admin + +import "sort" + +func normalizeInt64IDList(ids []int64) []int64 { + if 
len(ids) == 0 { + return nil + } + + out := make([]int64, 0, len(ids)) + seen := make(map[int64]struct{}, len(ids)) + for _, id := range ids { + if id <= 0 { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + out = append(out, id) + } + + sort.Slice(out, func(i, j int) bool { return out[i] < out[j] }) + return out +} diff --git a/backend/internal/handler/admin/id_list_utils_test.go b/backend/internal/handler/admin/id_list_utils_test.go new file mode 100644 index 00000000..aa65d5c0 --- /dev/null +++ b/backend/internal/handler/admin/id_list_utils_test.go @@ -0,0 +1,57 @@ +//go:build unit + +package admin + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNormalizeInt64IDList(t *testing.T) { + tests := []struct { + name string + in []int64 + want []int64 + }{ + {"nil input", nil, nil}, + {"empty input", []int64{}, nil}, + {"single element", []int64{5}, []int64{5}}, + {"already sorted unique", []int64{1, 2, 3}, []int64{1, 2, 3}}, + {"duplicates removed", []int64{3, 1, 3, 2, 1}, []int64{1, 2, 3}}, + {"zero filtered", []int64{0, 1, 2}, []int64{1, 2}}, + {"negative filtered", []int64{-5, -1, 3}, []int64{3}}, + {"all invalid", []int64{0, -1, -2}, []int64{}}, + {"sorted output", []int64{9, 3, 7, 1}, []int64{1, 3, 7, 9}}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := normalizeInt64IDList(tc.in) + if tc.want == nil { + require.Nil(t, got) + } else { + require.Equal(t, tc.want, got) + } + }) + } +} + +func TestBuildAccountTodayStatsBatchCacheKey(t *testing.T) { + tests := []struct { + name string + ids []int64 + want string + }{ + {"empty", nil, "accounts_today_stats_empty"}, + {"single", []int64{42}, "accounts_today_stats:42"}, + {"multiple", []int64{1, 2, 3}, "accounts_today_stats:1,2,3"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := buildAccountTodayStatsBatchCacheKey(tc.ids) + require.Equal(t, tc.want, got) + }) + } +} diff --git 
a/backend/internal/handler/admin/idempotency_helper.go b/backend/internal/handler/admin/idempotency_helper.go new file mode 100644 index 00000000..aa8eeaaf --- /dev/null +++ b/backend/internal/handler/admin/idempotency_helper.go @@ -0,0 +1,115 @@ +package admin + +import ( + "context" + "strconv" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +type idempotencyStoreUnavailableMode int + +const ( + idempotencyStoreUnavailableFailClose idempotencyStoreUnavailableMode = iota + idempotencyStoreUnavailableFailOpen +) + +func executeAdminIdempotent( + c *gin.Context, + scope string, + payload any, + ttl time.Duration, + execute func(context.Context) (any, error), +) (*service.IdempotencyExecuteResult, error) { + coordinator := service.DefaultIdempotencyCoordinator() + if coordinator == nil { + data, err := execute(c.Request.Context()) + if err != nil { + return nil, err + } + return &service.IdempotencyExecuteResult{Data: data}, nil + } + + actorScope := "admin:0" + if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok { + actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10) + } + + return coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{ + Scope: scope, + ActorScope: actorScope, + Method: c.Request.Method, + Route: c.FullPath(), + IdempotencyKey: c.GetHeader("Idempotency-Key"), + Payload: payload, + RequireKey: true, + TTL: ttl, + }, execute) +} + +func executeAdminIdempotentJSON( + c *gin.Context, + scope string, + payload any, + ttl time.Duration, + execute func(context.Context) (any, error), +) { + executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailClose, execute) +} + +func 
executeAdminIdempotentJSONFailOpenOnStoreUnavailable( + c *gin.Context, + scope string, + payload any, + ttl time.Duration, + execute func(context.Context) (any, error), +) { + executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailOpen, execute) +} + +func executeAdminIdempotentJSONWithMode( + c *gin.Context, + scope string, + payload any, + ttl time.Duration, + mode idempotencyStoreUnavailableMode, + execute func(context.Context) (any, error), +) { + result, err := executeAdminIdempotent(c, scope, payload, ttl, execute) + if err != nil { + if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) { + strategy := "fail_close" + if mode == idempotencyStoreUnavailableFailOpen { + strategy = "fail_open" + } + service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_"+strategy) + logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=%s", c.Request.Method, c.FullPath(), scope, strategy) + if mode == idempotencyStoreUnavailableFailOpen { + data, fallbackErr := execute(c.Request.Context()) + if fallbackErr != nil { + response.ErrorFrom(c, fallbackErr) + return + } + c.Header("X-Idempotency-Degraded", "store-unavailable") + response.Success(c, data) + return + } + } + if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } + response.ErrorFrom(c, err) + return + } + if result != nil && result.Replayed { + c.Header("X-Idempotency-Replayed", "true") + } + response.Success(c, result.Data) +} diff --git a/backend/internal/handler/admin/idempotency_helper_test.go b/backend/internal/handler/admin/idempotency_helper_test.go new file mode 100644 index 00000000..7dd86e16 --- /dev/null +++ b/backend/internal/handler/admin/idempotency_helper_test.go @@ -0,0 +1,285 @@ +package admin + +import ( + "bytes" + "context" + "errors" + "net/http" + "net/http/httptest" + "sync" + 
"sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type storeUnavailableRepoStub struct{} + +func (storeUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) { + return false, errors.New("store unavailable") +} +func (storeUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) { + return nil, errors.New("store unavailable") +} +func (storeUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + return false, errors.New("store unavailable") +} +func (storeUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return false, errors.New("store unavailable") +} +func (storeUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error { + return errors.New("store unavailable") +} +func (storeUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error { + return errors.New("store unavailable") +} +func (storeUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) { + return 0, errors.New("store unavailable") +} + +func TestExecuteAdminIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) { + gin.SetMode(gin.TestMode) + service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig())) + t.Cleanup(func() { + service.SetDefaultIdempotencyCoordinator(nil) + }) + + var executed int + router := gin.New() + router.POST("/idempotent", func(c *gin.Context) { + executeAdminIdempotentJSON(c, "admin.test.high", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) { + executed++ + return gin.H{"ok": true}, nil + }) + }) + + req := httptest.NewRequest(http.MethodPost, 
"/idempotent", bytes.NewBufferString(`{"a":1}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", "test-key-1") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusServiceUnavailable, rec.Code) + require.Equal(t, 0, executed, "fail-close should block business execution when idempotency store is unavailable") +} + +func TestExecuteAdminIdempotentJSONFailOpenOnStoreUnavailable(t *testing.T) { + gin.SetMode(gin.TestMode) + service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig())) + t.Cleanup(func() { + service.SetDefaultIdempotencyCoordinator(nil) + }) + + var executed int + router := gin.New() + router.POST("/idempotent", func(c *gin.Context) { + executeAdminIdempotentJSONFailOpenOnStoreUnavailable(c, "admin.test.medium", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) { + executed++ + return gin.H{"ok": true}, nil + }) + }) + + req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", "test-key-2") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "store-unavailable", rec.Header().Get("X-Idempotency-Degraded")) + require.Equal(t, 1, executed, "fail-open strategy should allow semantic idempotent path to continue") +} + +type memoryIdempotencyRepoStub struct { + mu sync.Mutex + nextID int64 + data map[string]*service.IdempotencyRecord +} + +func newMemoryIdempotencyRepoStub() *memoryIdempotencyRepoStub { + return &memoryIdempotencyRepoStub{ + nextID: 1, + data: make(map[string]*service.IdempotencyRecord), + } +} + +func (r *memoryIdempotencyRepoStub) key(scope, keyHash string) string { + return scope + "|" + keyHash +} + +func (r *memoryIdempotencyRepoStub) clone(in 
*service.IdempotencyRecord) *service.IdempotencyRecord { + if in == nil { + return nil + } + out := *in + if in.LockedUntil != nil { + v := *in.LockedUntil + out.LockedUntil = &v + } + if in.ResponseBody != nil { + v := *in.ResponseBody + out.ResponseBody = &v + } + if in.ResponseStatus != nil { + v := *in.ResponseStatus + out.ResponseStatus = &v + } + if in.ErrorReason != nil { + v := *in.ErrorReason + out.ErrorReason = &v + } + return &out +} + +func (r *memoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + k := r.key(record.Scope, record.IdempotencyKeyHash) + if _, ok := r.data[k]; ok { + return false, nil + } + cp := r.clone(record) + cp.ID = r.nextID + r.nextID++ + r.data[k] = cp + record.ID = cp.ID + return true, nil +} + +func (r *memoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) { + r.mu.Lock() + defer r.mu.Unlock() + return r.clone(r.data[r.key(scope, keyHash)]), nil +} + +func (r *memoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + if rec.Status != fromStatus { + return false, nil + } + if rec.LockedUntil != nil && rec.LockedUntil.After(now) { + return false, nil + } + rec.Status = service.IdempotencyStatusProcessing + rec.LockedUntil = &newLockedUntil + rec.ExpiresAt = newExpiresAt + rec.ErrorReason = nil + return true, nil + } + return false, nil +} + +func (r *memoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + if rec.Status != service.IdempotencyStatusProcessing || 
rec.RequestFingerprint != requestFingerprint { + return false, nil + } + rec.LockedUntil = &newLockedUntil + rec.ExpiresAt = newExpiresAt + return true, nil + } + return false, nil +} + +func (r *memoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + rec.Status = service.IdempotencyStatusSucceeded + rec.LockedUntil = nil + rec.ExpiresAt = expiresAt + rec.ResponseStatus = &responseStatus + rec.ResponseBody = &responseBody + rec.ErrorReason = nil + return nil + } + return nil +} + +func (r *memoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + rec.Status = service.IdempotencyStatusFailedRetryable + rec.LockedUntil = &lockedUntil + rec.ExpiresAt = expiresAt + rec.ErrorReason = &errorReason + return nil + } + return nil +} + +func (r *memoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ int) (int64, error) { + return 0, nil +} + +func TestExecuteAdminIdempotentJSONConcurrentRetryOnlyOneSideEffect(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := newMemoryIdempotencyRepoStub() + cfg := service.DefaultIdempotencyConfig() + cfg.ProcessingTimeout = 2 * time.Second + service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(repo, cfg)) + t.Cleanup(func() { + service.SetDefaultIdempotencyCoordinator(nil) + }) + + var executed atomic.Int32 + router := gin.New() + router.POST("/idempotent", func(c *gin.Context) { + executeAdminIdempotentJSON(c, "admin.test.concurrent", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) { + executed.Add(1) + time.Sleep(120 * time.Millisecond) + return gin.H{"ok": true}, nil + }) + }) + + call := func() (int, 
http.Header) { + req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", "same-key") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + return rec.Code, rec.Header() + } + + var status1, status2 int + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + status1, _ = call() + }() + go func() { + defer wg.Done() + status2, _ = call() + }() + wg.Wait() + + require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status1) + require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status2) + require.Equal(t, int32(1), executed.Load(), "same idempotency key should execute side-effect only once") + + status3, headers3 := call() + require.Equal(t, http.StatusOK, status3) + require.Equal(t, "true", headers3.Get("X-Idempotency-Replayed")) + require.Equal(t, int32(1), executed.Load()) +} diff --git a/backend/internal/handler/admin/openai_oauth_handler.go b/backend/internal/handler/admin/openai_oauth_handler.go index ed86fea9..5d354fd3 100644 --- a/backend/internal/handler/admin/openai_oauth_handler.go +++ b/backend/internal/handler/admin/openai_oauth_handler.go @@ -2,8 +2,10 @@ package admin import ( "strconv" + "strings" "github.com/Wei-Shaw/sub2api/internal/handler/dto" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" @@ -16,6 +18,13 @@ type OpenAIOAuthHandler struct { adminService service.AdminService } +func oauthPlatformFromPath(c *gin.Context) string { + if strings.Contains(c.FullPath(), "/admin/sora/") { + return service.PlatformSora + } + return service.PlatformOpenAI +} + // NewOpenAIOAuthHandler creates a new OpenAI OAuth handler func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler { return &OpenAIOAuthHandler{ @@ -39,7 +48,12 
@@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) { req = OpenAIGenerateAuthURLRequest{} } - result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI) + result, err := h.openaiOAuthService.GenerateAuthURL( + c.Request.Context(), + req.ProxyID, + req.RedirectURI, + oauthPlatformFromPath(c), + ) if err != nil { response.ErrorFrom(c, err) return @@ -52,6 +66,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) { type OpenAIExchangeCodeRequest struct { SessionID string `json:"session_id" binding:"required"` Code string `json:"code" binding:"required"` + State string `json:"state" binding:"required"` RedirectURI string `json:"redirect_uri"` ProxyID *int64 `json:"proxy_id"` } @@ -68,6 +83,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) { tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{ SessionID: req.SessionID, Code: req.Code, + State: req.State, RedirectURI: req.RedirectURI, ProxyID: req.ProxyID, }) @@ -81,18 +97,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) { // OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token type OpenAIRefreshTokenRequest struct { - RefreshToken string `json:"refresh_token" binding:"required"` + RefreshToken string `json:"refresh_token"` + RT string `json:"rt"` + ClientID string `json:"client_id"` ProxyID *int64 `json:"proxy_id"` } // RefreshToken refreshes an OpenAI OAuth token // POST /api/v1/admin/openai/refresh-token +// POST /api/v1/admin/sora/rt2at func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { var req OpenAIRefreshTokenRequest if err := c.ShouldBindJSON(&req); err != nil { response.BadRequest(c, "Invalid request: "+err.Error()) return } + refreshToken := strings.TrimSpace(req.RefreshToken) + if refreshToken == "" { + refreshToken = strings.TrimSpace(req.RT) + } + if refreshToken == "" { + response.BadRequest(c, "refresh_token is required") + 
return + } var proxyURL string if req.ProxyID != nil { @@ -102,7 +129,14 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { } } - tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL) + // 未指定 client_id 时,根据请求路径平台自动设置默认值,避免 repository 层盲猜 + clientID := strings.TrimSpace(req.ClientID) + if clientID == "" { + platform := oauthPlatformFromPath(c) + clientID, _ = openai.OAuthClientConfigByPlatform(platform) + } + + tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, clientID) if err != nil { response.ErrorFrom(c, err) return @@ -111,8 +145,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { response.Success(c, tokenInfo) } -// RefreshAccountToken refreshes token for a specific OpenAI account +// ExchangeSoraSessionToken exchanges Sora session token to access token +// POST /api/v1/admin/sora/st2at +func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) { + var req struct { + SessionToken string `json:"session_token"` + ST string `json:"st"` + ProxyID *int64 `json:"proxy_id"` + } + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + sessionToken := strings.TrimSpace(req.SessionToken) + if sessionToken == "" { + sessionToken = strings.TrimSpace(req.ST) + } + if sessionToken == "" { + response.BadRequest(c, "session_token is required") + return + } + + tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, tokenInfo) +} + +// RefreshAccountToken refreshes token for a specific OpenAI/Sora account // POST /api/v1/admin/openai/accounts/:id/refresh +// POST /api/v1/admin/sora/accounts/:id/refresh func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != 
nil { @@ -127,9 +192,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { return } - // Ensure account is OpenAI platform - if !account.IsOpenAI() { - response.BadRequest(c, "Account is not an OpenAI account") + platform := oauthPlatformFromPath(c) + if account.Platform != platform { + response.BadRequest(c, "Account platform does not match OAuth endpoint") return } @@ -167,12 +232,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { response.Success(c, dto.AccountFromService(updatedAccount)) } -// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info +// CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info // POST /api/v1/admin/openai/create-from-oauth +// POST /api/v1/admin/sora/create-from-oauth func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { var req struct { SessionID string `json:"session_id" binding:"required"` Code string `json:"code" binding:"required"` + State string `json:"state" binding:"required"` RedirectURI string `json:"redirect_uri"` ProxyID *int64 `json:"proxy_id"` Name string `json:"name"` @@ -189,6 +256,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{ SessionID: req.SessionID, Code: req.Code, + State: req.State, RedirectURI: req.RedirectURI, ProxyID: req.ProxyID, }) @@ -200,19 +268,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { // Build credentials from token info credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo) + platform := oauthPlatformFromPath(c) + // Use email as default name if not provided name := req.Name if name == "" && tokenInfo.Email != "" { name = tokenInfo.Email } if name == "" { - name = "OpenAI OAuth Account" + if platform == service.PlatformSora { + name = "Sora OAuth Account" + } else { + name = "OpenAI OAuth Account" + } } // Create account 
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{ Name: name, - Platform: "openai", + Platform: platform, Type: "oauth", Credentials: credentials, ProxyID: req.ProxyID, diff --git a/backend/internal/handler/admin/ops_dashboard_handler.go b/backend/internal/handler/admin/ops_dashboard_handler.go index 2c87f734..01f7bc2b 100644 --- a/backend/internal/handler/admin/ops_dashboard_handler.go +++ b/backend/internal/handler/admin/ops_dashboard_handler.go @@ -1,6 +1,7 @@ package admin import ( + "fmt" "net/http" "strconv" "strings" @@ -218,6 +219,115 @@ func (h *OpsHandler) GetDashboardErrorDistribution(c *gin.Context) { response.Success(c, data) } +// GetDashboardOpenAITokenStats returns OpenAI token efficiency stats grouped by model. +// GET /api/v1/admin/ops/dashboard/openai-token-stats +func (h *OpsHandler) GetDashboardOpenAITokenStats(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + filter, err := parseOpsOpenAITokenStatsFilter(c) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + data, err := h.opsService.GetOpenAITokenStats(c.Request.Context(), filter) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, data) +} + +func parseOpsOpenAITokenStatsFilter(c *gin.Context) (*service.OpsOpenAITokenStatsFilter, error) { + if c == nil { + return nil, fmt.Errorf("invalid request") + } + + timeRange := strings.TrimSpace(c.Query("time_range")) + if timeRange == "" { + timeRange = "30d" + } + dur, ok := parseOpsOpenAITokenStatsDuration(timeRange) + if !ok { + return nil, fmt.Errorf("invalid time_range") + } + end := time.Now().UTC() + start := end.Add(-dur) + + filter := &service.OpsOpenAITokenStatsFilter{ + TimeRange: timeRange, + StartTime: start, + EndTime: end, + 
Platform: strings.TrimSpace(c.Query("platform")), + } + + if v := strings.TrimSpace(c.Query("group_id")); v != "" { + id, err := strconv.ParseInt(v, 10, 64) + if err != nil || id <= 0 { + return nil, fmt.Errorf("invalid group_id") + } + filter.GroupID = &id + } + + topNRaw := strings.TrimSpace(c.Query("top_n")) + pageRaw := strings.TrimSpace(c.Query("page")) + pageSizeRaw := strings.TrimSpace(c.Query("page_size")) + if topNRaw != "" && (pageRaw != "" || pageSizeRaw != "") { + return nil, fmt.Errorf("invalid query: top_n cannot be used with page/page_size") + } + + if topNRaw != "" { + topN, err := strconv.Atoi(topNRaw) + if err != nil || topN < 1 || topN > 100 { + return nil, fmt.Errorf("invalid top_n") + } + filter.TopN = topN + return filter, nil + } + + filter.Page = 1 + filter.PageSize = 20 + if pageRaw != "" { + page, err := strconv.Atoi(pageRaw) + if err != nil || page < 1 { + return nil, fmt.Errorf("invalid page") + } + filter.Page = page + } + if pageSizeRaw != "" { + pageSize, err := strconv.Atoi(pageSizeRaw) + if err != nil || pageSize < 1 || pageSize > 100 { + return nil, fmt.Errorf("invalid page_size") + } + filter.PageSize = pageSize + } + return filter, nil +} + +func parseOpsOpenAITokenStatsDuration(v string) (time.Duration, bool) { + switch strings.TrimSpace(v) { + case "30m": + return 30 * time.Minute, true + case "1h": + return time.Hour, true + case "1d": + return 24 * time.Hour, true + case "15d": + return 15 * 24 * time.Hour, true + case "30d": + return 30 * 24 * time.Hour, true + default: + return 0, false + } +} + func pickThroughputBucketSeconds(window time.Duration) int { // Keep buckets predictable and avoid huge responses. 
switch { diff --git a/backend/internal/handler/admin/ops_runtime_logging_handler_test.go b/backend/internal/handler/admin/ops_runtime_logging_handler_test.go new file mode 100644 index 00000000..0e84b4f9 --- /dev/null +++ b/backend/internal/handler/admin/ops_runtime_logging_handler_test.go @@ -0,0 +1,173 @@ +package admin + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +type testSettingRepo struct { + values map[string]string +} + +func newTestSettingRepo() *testSettingRepo { + return &testSettingRepo{values: map[string]string{}} +} + +func (s *testSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) { + v, err := s.GetValue(ctx, key) + if err != nil { + return nil, err + } + return &service.Setting{Key: key, Value: v}, nil +} +func (s *testSettingRepo) GetValue(ctx context.Context, key string) (string, error) { + v, ok := s.values[key] + if !ok { + return "", service.ErrSettingNotFound + } + return v, nil +} +func (s *testSettingRepo) Set(ctx context.Context, key, value string) error { + s.values[key] = value + return nil +} +func (s *testSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, k := range keys { + if v, ok := s.values[k]; ok { + out[k] = v + } + } + return out, nil +} +func (s *testSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error { + for k, v := range settings { + s.values[k] = v + } + return nil +} +func (s *testSettingRepo) GetAll(ctx context.Context) (map[string]string, error) { + out := make(map[string]string, len(s.values)) + for k, v := range s.values { + out[k] = v + } + return out, nil +} +func (s 
*testSettingRepo) Delete(ctx context.Context, key string) error { + delete(s.values, key) + return nil +} + +func newOpsRuntimeRouter(handler *OpsHandler, withUser bool) *gin.Engine { + gin.SetMode(gin.TestMode) + r := gin.New() + if withUser { + r.Use(func(c *gin.Context) { + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 7}) + c.Next() + }) + } + r.GET("/runtime/logging", handler.GetRuntimeLogConfig) + r.PUT("/runtime/logging", handler.UpdateRuntimeLogConfig) + r.POST("/runtime/logging/reset", handler.ResetRuntimeLogConfig) + return r +} + +func newRuntimeOpsService(t *testing.T) *service.OpsService { + t.Helper() + if err := logger.Init(logger.InitOptions{ + Level: "info", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: false, + ToFile: false, + }, + }); err != nil { + t.Fatalf("init logger: %v", err) + } + + settingRepo := newTestSettingRepo() + cfg := &config.Config{ + Ops: config.OpsConfig{Enabled: true}, + Log: config.LogConfig{ + Level: "info", + Caller: true, + StacktraceLevel: "error", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + } + return service.NewOpsService(nil, settingRepo, cfg, nil, nil, nil, nil, nil, nil, nil, nil) +} + +func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) { + h := NewOpsHandler(newRuntimeOpsService(t)) + r := newOpsRuntimeRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/runtime/logging", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("status=%d, want 200", w.Code) + } +} + +func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) { + h := NewOpsHandler(newRuntimeOpsService(t)) + r := newOpsRuntimeRouter(h, false) + + body := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}` + w := 
httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusUnauthorized { + t.Fatalf("status=%d, want 401", w.Code) + } +} + +func TestOpsRuntimeLoggingHandler_UpdateAndResetSuccess(t *testing.T) { + h := NewOpsHandler(newRuntimeOpsService(t)) + r := newOpsRuntimeRouter(h, true) + + payload := map[string]any{ + "level": "debug", + "enable_sampling": false, + "sampling_initial": 100, + "sampling_thereafter": 100, + "caller": true, + "stacktrace_level": "error", + "retention_days": 30, + } + raw, _ := json.Marshal(payload) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewReader(raw)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("update status=%d, want 200, body=%s", w.Code, w.Body.String()) + } + + w = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/runtime/logging/reset", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("reset status=%d, want 200, body=%s", w.Code, w.Body.String()) + } +} diff --git a/backend/internal/handler/admin/ops_settings_handler.go b/backend/internal/handler/admin/ops_settings_handler.go index ebc8bf49..226b89f3 100644 --- a/backend/internal/handler/admin/ops_settings_handler.go +++ b/backend/internal/handler/admin/ops_settings_handler.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" ) @@ -101,6 +102,84 @@ func (h *OpsHandler) UpdateAlertRuntimeSettings(c *gin.Context) { response.Success(c, updated) } +// GetRuntimeLogConfig returns runtime log config (DB-backed). 
+// GET /api/v1/admin/ops/runtime/logging +func (h *OpsHandler) GetRuntimeLogConfig(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + cfg, err := h.opsService.GetRuntimeLogConfig(c.Request.Context()) + if err != nil { + response.Error(c, http.StatusInternalServerError, "Failed to get runtime log config") + return + } + response.Success(c, cfg) +} + +// UpdateRuntimeLogConfig updates runtime log config and applies changes immediately. +// PUT /api/v1/admin/ops/runtime/logging +func (h *OpsHandler) UpdateRuntimeLogConfig(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + var req service.OpsRuntimeLogConfig + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request body") + return + } + + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + response.Error(c, http.StatusUnauthorized, "Unauthorized") + return + } + + updated, err := h.opsService.UpdateRuntimeLogConfig(c.Request.Context(), &req, subject.UserID) + if err != nil { + response.Error(c, http.StatusBadRequest, err.Error()) + return + } + response.Success(c, updated) +} + +// ResetRuntimeLogConfig removes runtime override and falls back to env/yaml baseline. 
+// POST /api/v1/admin/ops/runtime/logging/reset +func (h *OpsHandler) ResetRuntimeLogConfig(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + response.Error(c, http.StatusUnauthorized, "Unauthorized") + return + } + + updated, err := h.opsService.ResetRuntimeLogConfig(c.Request.Context(), subject.UserID) + if err != nil { + response.Error(c, http.StatusBadRequest, err.Error()) + return + } + response.Success(c, updated) +} + // GetAdvancedSettings returns Ops advanced settings (DB-backed). // GET /api/v1/admin/ops/advanced-settings func (h *OpsHandler) GetAdvancedSettings(c *gin.Context) { diff --git a/backend/internal/handler/admin/ops_snapshot_v2_handler.go b/backend/internal/handler/admin/ops_snapshot_v2_handler.go new file mode 100644 index 00000000..5cac00fe --- /dev/null +++ b/backend/internal/handler/admin/ops_snapshot_v2_handler.go @@ -0,0 +1,145 @@ +package admin + +import ( + "encoding/json" + "net/http" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "golang.org/x/sync/errgroup" +) + +var opsDashboardSnapshotV2Cache = newSnapshotCache(30 * time.Second) + +type opsDashboardSnapshotV2Response struct { + GeneratedAt string `json:"generated_at"` + + Overview *service.OpsDashboardOverview `json:"overview"` + ThroughputTrend *service.OpsThroughputTrendResponse `json:"throughput_trend"` + ErrorTrend *service.OpsErrorTrendResponse `json:"error_trend"` +} + +type opsDashboardSnapshotV2CacheKey struct { + StartTime string `json:"start_time"` + EndTime string `json:"end_time"` + Platform string `json:"platform"` + GroupID *int64 
`json:"group_id"` + QueryMode service.OpsQueryMode `json:"mode"` + BucketSecond int `json:"bucket_second"` +} + +// GetDashboardSnapshotV2 returns ops dashboard core snapshot in one request. +// GET /api/v1/admin/ops/dashboard/snapshot-v2 +func (h *OpsHandler) GetDashboardSnapshotV2(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + startTime, endTime, err := parseOpsTimeRange(c, "1h") + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + filter := &service.OpsDashboardFilter{ + StartTime: startTime, + EndTime: endTime, + Platform: strings.TrimSpace(c.Query("platform")), + QueryMode: parseOpsQueryMode(c), + } + if v := strings.TrimSpace(c.Query("group_id")); v != "" { + id, err := strconv.ParseInt(v, 10, 64) + if err != nil || id <= 0 { + response.BadRequest(c, "Invalid group_id") + return + } + filter.GroupID = &id + } + bucketSeconds := pickThroughputBucketSeconds(endTime.Sub(startTime)) + + keyRaw, _ := json.Marshal(opsDashboardSnapshotV2CacheKey{ + StartTime: startTime.UTC().Format(time.RFC3339), + EndTime: endTime.UTC().Format(time.RFC3339), + Platform: filter.Platform, + GroupID: filter.GroupID, + QueryMode: filter.QueryMode, + BucketSecond: bucketSeconds, + }) + cacheKey := string(keyRaw) + + if cached, ok := opsDashboardSnapshotV2Cache.Get(cacheKey); ok { + if cached.ETag != "" { + c.Header("ETag", cached.ETag) + c.Header("Vary", "If-None-Match") + if ifNoneMatchMatched(c.GetHeader("If-None-Match"), cached.ETag) { + c.Status(http.StatusNotModified) + return + } + } + c.Header("X-Snapshot-Cache", "hit") + response.Success(c, cached.Payload) + return + } + + var ( + overview *service.OpsDashboardOverview + trend *service.OpsThroughputTrendResponse + errTrend *service.OpsErrorTrendResponse + ) + g, gctx := 
errgroup.WithContext(c.Request.Context()) + g.Go(func() error { + f := *filter + result, err := h.opsService.GetDashboardOverview(gctx, &f) + if err != nil { + return err + } + overview = result + return nil + }) + g.Go(func() error { + f := *filter + result, err := h.opsService.GetThroughputTrend(gctx, &f, bucketSeconds) + if err != nil { + return err + } + trend = result + return nil + }) + g.Go(func() error { + f := *filter + result, err := h.opsService.GetErrorTrend(gctx, &f, bucketSeconds) + if err != nil { + return err + } + errTrend = result + return nil + }) + if err := g.Wait(); err != nil { + response.ErrorFrom(c, err) + return + } + + resp := &opsDashboardSnapshotV2Response{ + GeneratedAt: time.Now().UTC().Format(time.RFC3339), + Overview: overview, + ThroughputTrend: trend, + ErrorTrend: errTrend, + } + + cached := opsDashboardSnapshotV2Cache.Set(cacheKey, resp) + if cached.ETag != "" { + c.Header("ETag", cached.ETag) + c.Header("Vary", "If-None-Match") + } + c.Header("X-Snapshot-Cache", "miss") + response.Success(c, resp) +} diff --git a/backend/internal/handler/admin/ops_system_log_handler.go b/backend/internal/handler/admin/ops_system_log_handler.go new file mode 100644 index 00000000..31fd51eb --- /dev/null +++ b/backend/internal/handler/admin/ops_system_log_handler.go @@ -0,0 +1,174 @@ +package admin + +import ( + "net/http" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +type opsSystemLogCleanupRequest struct { + StartTime string `json:"start_time"` + EndTime string `json:"end_time"` + + Level string `json:"level"` + Component string `json:"component"` + RequestID string `json:"request_id"` + ClientRequestID string `json:"client_request_id"` + UserID *int64 `json:"user_id"` + AccountID *int64 `json:"account_id"` + Platform string `json:"platform"` + Model string 
`json:"model"` + Query string `json:"q"` +} + +// ListSystemLogs returns indexed system logs. +// GET /api/v1/admin/ops/system-logs +func (h *OpsHandler) ListSystemLogs(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + page, pageSize := response.ParsePagination(c) + if pageSize > 200 { + pageSize = 200 + } + + start, end, err := parseOpsTimeRange(c, "1h") + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + filter := &service.OpsSystemLogFilter{ + Page: page, + PageSize: pageSize, + StartTime: &start, + EndTime: &end, + Level: strings.TrimSpace(c.Query("level")), + Component: strings.TrimSpace(c.Query("component")), + RequestID: strings.TrimSpace(c.Query("request_id")), + ClientRequestID: strings.TrimSpace(c.Query("client_request_id")), + Platform: strings.TrimSpace(c.Query("platform")), + Model: strings.TrimSpace(c.Query("model")), + Query: strings.TrimSpace(c.Query("q")), + } + if v := strings.TrimSpace(c.Query("user_id")); v != "" { + id, parseErr := strconv.ParseInt(v, 10, 64) + if parseErr != nil || id <= 0 { + response.BadRequest(c, "Invalid user_id") + return + } + filter.UserID = &id + } + if v := strings.TrimSpace(c.Query("account_id")); v != "" { + id, parseErr := strconv.ParseInt(v, 10, 64) + if parseErr != nil || id <= 0 { + response.BadRequest(c, "Invalid account_id") + return + } + filter.AccountID = &id + } + + result, err := h.opsService.ListSystemLogs(c.Request.Context(), filter) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Paginated(c, result.Logs, int64(result.Total), result.Page, result.PageSize) +} + +// CleanupSystemLogs deletes indexed system logs by filter. 
+// POST /api/v1/admin/ops/system-logs/cleanup +func (h *OpsHandler) CleanupSystemLogs(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + response.Error(c, http.StatusUnauthorized, "Unauthorized") + return + } + + var req opsSystemLogCleanupRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request body") + return + } + + parseTS := func(raw string) (*time.Time, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil, nil + } + if t, err := time.Parse(time.RFC3339Nano, raw); err == nil { + return &t, nil + } + t, err := time.Parse(time.RFC3339, raw) + if err != nil { + return nil, err + } + return &t, nil + } + start, err := parseTS(req.StartTime) + if err != nil { + response.BadRequest(c, "Invalid start_time") + return + } + end, err := parseTS(req.EndTime) + if err != nil { + response.BadRequest(c, "Invalid end_time") + return + } + + filter := &service.OpsSystemLogCleanupFilter{ + StartTime: start, + EndTime: end, + Level: strings.TrimSpace(req.Level), + Component: strings.TrimSpace(req.Component), + RequestID: strings.TrimSpace(req.RequestID), + ClientRequestID: strings.TrimSpace(req.ClientRequestID), + UserID: req.UserID, + AccountID: req.AccountID, + Platform: strings.TrimSpace(req.Platform), + Model: strings.TrimSpace(req.Model), + Query: strings.TrimSpace(req.Query), + } + + deleted, err := h.opsService.CleanupSystemLogs(c.Request.Context(), filter, subject.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"deleted": deleted}) +} + +// GetSystemLogIngestionHealth returns sink health metrics. 
+// GET /api/v1/admin/ops/system-logs/health +func (h *OpsHandler) GetSystemLogIngestionHealth(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, h.opsService.GetSystemLogSinkHealth()) +} diff --git a/backend/internal/handler/admin/ops_system_log_handler_test.go b/backend/internal/handler/admin/ops_system_log_handler_test.go new file mode 100644 index 00000000..7528acd8 --- /dev/null +++ b/backend/internal/handler/admin/ops_system_log_handler_test.go @@ -0,0 +1,233 @@ +package admin + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +type responseEnvelope struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` +} + +func newOpsSystemLogTestRouter(handler *OpsHandler, withUser bool) *gin.Engine { + gin.SetMode(gin.TestMode) + r := gin.New() + if withUser { + r.Use(func(c *gin.Context) { + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 99}) + c.Next() + }) + } + r.GET("/logs", handler.ListSystemLogs) + r.POST("/logs/cleanup", handler.CleanupSystemLogs) + r.GET("/logs/health", handler.GetSystemLogIngestionHealth) + return r +} + +func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) { + h := NewOpsHandler(nil) + r := newOpsSystemLogTestRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/logs", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusServiceUnavailable { + t.Fatalf("status=%d, want 503", w.Code) + } +} + +func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) { + svc := 
service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/logs?user_id=abc", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusBadRequest { + t.Fatalf("status=%d, want 400", w.Code) + } +} + +func TestOpsSystemLogHandler_ListInvalidAccountID(t *testing.T) { + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/logs?account_id=-1", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusBadRequest { + t.Fatalf("status=%d, want 400", w.Code) + } +} + +func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) { + svc := service.NewOpsService(nil, nil, &config.Config{ + Ops: config.OpsConfig{Enabled: false}, + }, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/logs", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusNotFound { + t.Fatalf("status=%d, want 404", w.Code) + } +} + +func TestOpsSystemLogHandler_ListSuccess(t *testing.T) { + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/logs?time_range=30m&page=1&page_size=20", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("status=%d, want 200", w.Code) + } + + var resp responseEnvelope + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if resp.Code != 0 { + t.Fatalf("unexpected response code: %+v", resp) + } +} + +func TestOpsSystemLogHandler_CleanupUnauthorized(t 
*testing.T) { + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusUnauthorized { + t.Fatalf("status=%d, want 401", w.Code) + } +} + +func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) { + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, true) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{bad-json`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusBadRequest { + t.Fatalf("status=%d, want 400", w.Code) + } +} + +func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) { + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, true) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"start_time":"bad","request_id":"r1"}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusBadRequest { + t.Fatalf("status=%d, want 400", w.Code) + } +} + +func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) { + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, true) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"end_time":"bad","request_id":"r1"}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if 
w.Code != http.StatusBadRequest { + t.Fatalf("status=%d, want 400", w.Code) + } +} + +func TestOpsSystemLogHandler_CleanupServiceUnavailable(t *testing.T) { + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, true) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusServiceUnavailable { + t.Fatalf("status=%d, want 503", w.Code) + } +} + +func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) { + svc := service.NewOpsService(nil, nil, &config.Config{ + Ops: config.OpsConfig{Enabled: false}, + }, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, true) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusNotFound { + t.Fatalf("status=%d, want 404", w.Code) + } +} + +func TestOpsSystemLogHandler_Health(t *testing.T) { + sink := service.NewOpsSystemLogSink(nil) + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/logs/health", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("status=%d, want 200", w.Code) + } +} + +func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T) { + h := NewOpsHandler(nil) + r := newOpsSystemLogTestRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/logs/health", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusServiceUnavailable { + 
t.Fatalf("status=%d, want 503", w.Code) + } + + svc := service.NewOpsService(nil, nil, &config.Config{ + Ops: config.OpsConfig{Enabled: false}, + }, nil, nil, nil, nil, nil, nil, nil, nil) + h = NewOpsHandler(svc) + r = newOpsSystemLogTestRouter(h, false) + w = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/logs/health", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusNotFound { + t.Fatalf("status=%d, want 404", w.Code) + } +} diff --git a/backend/internal/handler/admin/ops_ws_handler.go b/backend/internal/handler/admin/ops_ws_handler.go index db7442e5..75fd7ea0 100644 --- a/backend/internal/handler/admin/ops_ws_handler.go +++ b/backend/internal/handler/admin/ops_ws_handler.go @@ -3,7 +3,6 @@ package admin import ( "context" "encoding/json" - "log" "math" "net" "net/http" @@ -16,6 +15,7 @@ import ( "sync/atomic" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" @@ -62,7 +62,8 @@ const ( ) var wsConnCount atomic.Int32 -var wsConnCountByIP sync.Map // map[string]*atomic.Int32 +var wsConnCountByIPMu sync.Mutex +var wsConnCountByIP = make(map[string]int32) const qpsWSIdleStopDelay = 30 * time.Second @@ -252,7 +253,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) { stats, err := opsService.GetWindowStats(ctx, now.Add(-c.requestCountWindow), now) if err != nil || stats == nil { if err != nil { - log.Printf("[OpsWS] refresh: get window stats failed: %v", err) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: get window stats failed: %v", err) } return } @@ -278,7 +279,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) { msg, err := json.Marshal(payload) if err != nil { - log.Printf("[OpsWS] refresh: marshal payload failed: %v", err) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: marshal payload failed: %v", err) return } @@ -338,7 +339,7 @@ func (h *OpsHandler) QPSWSHandler(c 
*gin.Context) { // Reserve a global slot before upgrading the connection to keep the limit strict. if !tryAcquireOpsWSTotalSlot(opsWSLimits.MaxConns) { - log.Printf("[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns) c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"}) return } @@ -350,7 +351,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) { if opsWSLimits.MaxConnsPerIP > 0 && clientIP != "" { if !tryAcquireOpsWSIPSlot(clientIP, opsWSLimits.MaxConnsPerIP) { - log.Printf("[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP) c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"}) return } @@ -359,7 +360,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) { conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { - log.Printf("[OpsWS] upgrade failed: %v", err) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] upgrade failed: %v", err) return } @@ -389,42 +390,31 @@ func tryAcquireOpsWSIPSlot(clientIP string, limit int32) bool { if strings.TrimSpace(clientIP) == "" || limit <= 0 { return true } - - v, _ := wsConnCountByIP.LoadOrStore(clientIP, &atomic.Int32{}) - counter, ok := v.(*atomic.Int32) - if !ok { + wsConnCountByIPMu.Lock() + defer wsConnCountByIPMu.Unlock() + current := wsConnCountByIP[clientIP] + if current >= limit { return false } - - for { - current := counter.Load() - if current >= limit { - return false - } - if counter.CompareAndSwap(current, current+1) { - return true - } - } + wsConnCountByIP[clientIP] = current + 1 + return true } func releaseOpsWSIPSlot(clientIP string) { if strings.TrimSpace(clientIP) == "" { return } - - v, ok := 
wsConnCountByIP.Load(clientIP) + wsConnCountByIPMu.Lock() + defer wsConnCountByIPMu.Unlock() + current, ok := wsConnCountByIP[clientIP] if !ok { return } - counter, ok := v.(*atomic.Int32) - if !ok { + if current <= 1 { + delete(wsConnCountByIP, clientIP) return } - next := counter.Add(-1) - if next <= 0 { - // Best-effort cleanup; safe even if a new slot was acquired concurrently. - wsConnCountByIP.Delete(clientIP) - } + wsConnCountByIP[clientIP] = current - 1 } func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { @@ -452,7 +442,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { conn.SetReadLimit(qpsWSMaxReadBytes) if err := conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)); err != nil { - log.Printf("[OpsWS] set read deadline failed: %v", err) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] set read deadline failed: %v", err) return } conn.SetPongHandler(func(string) error { @@ -471,7 +461,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { _, _, err := conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { - log.Printf("[OpsWS] read failed: %v", err) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] read failed: %v", err) } return } @@ -508,7 +498,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { continue } if err := writeWithTimeout(websocket.TextMessage, msg); err != nil { - log.Printf("[OpsWS] write failed: %v", err) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] write failed: %v", err) cancel() closeConn() wg.Wait() @@ -517,7 +507,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { case <-pingTicker.C: if err := writeWithTimeout(websocket.PingMessage, nil); err != nil { - log.Printf("[OpsWS] ping failed: %v", err) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] ping failed: %v", err) 
cancel() closeConn() wg.Wait() @@ -666,14 +656,14 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig { if parsed, err := strconv.ParseBool(v); err == nil { cfg.TrustProxy = parsed } else { - log.Printf("[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy) } } if raw := strings.TrimSpace(os.Getenv(envOpsWSTrustedProxies)); raw != "" { prefixes, invalid := parseTrustedProxyList(raw) if len(invalid) > 0 { - log.Printf("[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", ")) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", ")) } cfg.TrustedProxies = prefixes } @@ -684,7 +674,7 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig { case OriginPolicyStrict, OriginPolicyPermissive: cfg.OriginPolicy = normalized default: - log.Printf("[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy) } } @@ -701,14 +691,14 @@ func loadOpsWSRuntimeLimitsFromEnv() opsWSRuntimeLimits { if parsed, err := strconv.Atoi(v); err == nil && parsed > 0 { cfg.MaxConns = int32(parsed) } else { - log.Printf("[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns) } } if v := strings.TrimSpace(os.Getenv(envOpsWSMaxConnsPerIP)); v != "" { if parsed, err := strconv.Atoi(v); err == nil && parsed >= 0 { cfg.MaxConnsPerIP = 
int32(parsed) } else { - log.Printf("[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP) } } return cfg diff --git a/backend/internal/handler/admin/proxy_handler.go b/backend/internal/handler/admin/proxy_handler.go index a6758f69..e8ae0ce2 100644 --- a/backend/internal/handler/admin/proxy_handler.go +++ b/backend/internal/handler/admin/proxy_handler.go @@ -1,6 +1,7 @@ package admin import ( + "context" "strconv" "strings" @@ -63,9 +64,9 @@ func (h *ProxyHandler) List(c *gin.Context) { return } - out := make([]dto.ProxyWithAccountCount, 0, len(proxies)) + out := make([]dto.AdminProxyWithAccountCount, 0, len(proxies)) for i := range proxies { - out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i])) + out = append(out, *dto.ProxyWithAccountCountFromServiceAdmin(&proxies[i])) } response.Paginated(c, out, total, page, pageSize) } @@ -82,9 +83,9 @@ func (h *ProxyHandler) GetAll(c *gin.Context) { response.ErrorFrom(c, err) return } - out := make([]dto.ProxyWithAccountCount, 0, len(proxies)) + out := make([]dto.AdminProxyWithAccountCount, 0, len(proxies)) for i := range proxies { - out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i])) + out = append(out, *dto.ProxyWithAccountCountFromServiceAdmin(&proxies[i])) } response.Success(c, out) return @@ -96,9 +97,9 @@ func (h *ProxyHandler) GetAll(c *gin.Context) { return } - out := make([]dto.Proxy, 0, len(proxies)) + out := make([]dto.AdminProxy, 0, len(proxies)) for i := range proxies { - out = append(out, *dto.ProxyFromService(&proxies[i])) + out = append(out, *dto.ProxyFromServiceAdmin(&proxies[i])) } response.Success(c, out) } @@ -118,7 +119,7 @@ func (h *ProxyHandler) GetByID(c *gin.Context) { return } - response.Success(c, dto.ProxyFromService(proxy)) + response.Success(c, 
dto.ProxyFromServiceAdmin(proxy)) } // Create handles creating a new proxy @@ -130,20 +131,20 @@ func (h *ProxyHandler) Create(c *gin.Context) { return } - proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{ - Name: strings.TrimSpace(req.Name), - Protocol: strings.TrimSpace(req.Protocol), - Host: strings.TrimSpace(req.Host), - Port: req.Port, - Username: strings.TrimSpace(req.Username), - Password: strings.TrimSpace(req.Password), + executeAdminIdempotentJSON(c, "admin.proxies.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + proxy, err := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{ + Name: strings.TrimSpace(req.Name), + Protocol: strings.TrimSpace(req.Protocol), + Host: strings.TrimSpace(req.Host), + Port: req.Port, + Username: strings.TrimSpace(req.Username), + Password: strings.TrimSpace(req.Password), + }) + if err != nil { + return nil, err + } + return dto.ProxyFromServiceAdmin(proxy), nil }) - if err != nil { - response.ErrorFrom(c, err) - return - } - - response.Success(c, dto.ProxyFromService(proxy)) } // Update handles updating a proxy @@ -175,7 +176,7 @@ func (h *ProxyHandler) Update(c *gin.Context) { return } - response.Success(c, dto.ProxyFromService(proxy)) + response.Success(c, dto.ProxyFromServiceAdmin(proxy)) } // Delete handles deleting a proxy @@ -236,6 +237,24 @@ func (h *ProxyHandler) Test(c *gin.Context) { response.Success(c, result) } +// CheckQuality handles checking proxy quality across common AI targets. 
+// POST /api/v1/admin/proxies/:id/quality-check +func (h *ProxyHandler) CheckQuality(c *gin.Context) { + proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid proxy ID") + return + } + + result, err := h.adminService.CheckProxyQuality(c.Request.Context(), proxyID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, result) +} + // GetStats handles getting proxy statistics // GET /api/v1/admin/proxies/:id/stats func (h *ProxyHandler) GetStats(c *gin.Context) { diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go index 02752fea..0a932ee9 100644 --- a/backend/internal/handler/admin/redeem_handler.go +++ b/backend/internal/handler/admin/redeem_handler.go @@ -2,12 +2,15 @@ package admin import ( "bytes" + "context" "encoding/csv" + "errors" "fmt" "strconv" "strings" "github.com/Wei-Shaw/sub2api/internal/handler/dto" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" @@ -16,13 +19,15 @@ import ( // RedeemHandler handles admin redeem code management type RedeemHandler struct { - adminService service.AdminService + adminService service.AdminService + redeemService *service.RedeemService } // NewRedeemHandler creates a new admin redeem handler -func NewRedeemHandler(adminService service.AdminService) *RedeemHandler { +func NewRedeemHandler(adminService service.AdminService, redeemService *service.RedeemService) *RedeemHandler { return &RedeemHandler{ - adminService: adminService, + adminService: adminService, + redeemService: redeemService, } } @@ -35,6 +40,15 @@ type GenerateRedeemCodesRequest struct { ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用,默认30天,最大100年 } +// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user. 
+type CreateAndRedeemCodeRequest struct { + Code string `json:"code" binding:"required,min=3,max=128"` + Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"` + Value float64 `json:"value" binding:"required,gt=0"` + UserID int64 `json:"user_id" binding:"required,gt=0"` + Notes string `json:"notes"` +} + // List handles listing all redeem codes with pagination // GET /api/v1/admin/redeem-codes func (h *RedeemHandler) List(c *gin.Context) { @@ -88,23 +102,99 @@ func (h *RedeemHandler) Generate(c *gin.Context) { return } - codes, err := h.adminService.GenerateRedeemCodes(c.Request.Context(), &service.GenerateRedeemCodesInput{ - Count: req.Count, - Type: req.Type, - Value: req.Value, - GroupID: req.GroupID, - ValidityDays: req.ValidityDays, + executeAdminIdempotentJSON(c, "admin.redeem_codes.generate", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + codes, execErr := h.adminService.GenerateRedeemCodes(ctx, &service.GenerateRedeemCodesInput{ + Count: req.Count, + Type: req.Type, + Value: req.Value, + GroupID: req.GroupID, + ValidityDays: req.ValidityDays, + }) + if execErr != nil { + return nil, execErr + } + + out := make([]dto.AdminRedeemCode, 0, len(codes)) + for i := range codes { + out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i])) + } + return out, nil }) - if err != nil { - response.ErrorFrom(c, err) +} + +// CreateAndRedeem creates a fixed redeem code and redeems it for a target user in one step. 
+// POST /api/v1/admin/redeem-codes/create-and-redeem +func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) { + if h.redeemService == nil { + response.InternalError(c, "redeem service not configured") return } - out := make([]dto.AdminRedeemCode, 0, len(codes)) - for i := range codes { - out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i])) + var req CreateAndRedeemCodeRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return } - response.Success(c, out) + req.Code = strings.TrimSpace(req.Code) + + executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + existing, err := h.redeemService.GetByCode(ctx, req.Code) + if err == nil { + return h.resolveCreateAndRedeemExisting(ctx, existing, req.UserID) + } + if !errors.Is(err, service.ErrRedeemCodeNotFound) { + return nil, err + } + + createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{ + Code: req.Code, + Type: req.Type, + Value: req.Value, + Status: service.StatusUnused, + Notes: req.Notes, + }) + if createErr != nil { + // Unique code race: if code now exists, use idempotent semantics by used_by. 
+ existingAfterCreateErr, getErr := h.redeemService.GetByCode(ctx, req.Code) + if getErr == nil { + return h.resolveCreateAndRedeemExisting(ctx, existingAfterCreateErr, req.UserID) + } + return nil, createErr + } + + redeemed, redeemErr := h.redeemService.Redeem(ctx, req.UserID, req.Code) + if redeemErr != nil { + return nil, redeemErr + } + return gin.H{"redeem_code": dto.RedeemCodeFromServiceAdmin(redeemed)}, nil + }) +} + +func (h *RedeemHandler) resolveCreateAndRedeemExisting(ctx context.Context, existing *service.RedeemCode, userID int64) (any, error) { + if existing == nil { + return nil, infraerrors.Conflict("REDEEM_CODE_CONFLICT", "redeem code conflict") + } + + // If previous run created the code but crashed before redeem, redeem it now. + if existing.CanUse() { + redeemed, err := h.redeemService.Redeem(ctx, userID, existing.Code) + if err == nil { + return gin.H{"redeem_code": dto.RedeemCodeFromServiceAdmin(redeemed)}, nil + } + if !errors.Is(err, service.ErrRedeemCodeUsed) { + return nil, err + } + latest, getErr := h.redeemService.GetByCode(ctx, existing.Code) + if getErr == nil { + existing = latest + } + } + + if existing.UsedBy != nil && *existing.UsedBy == userID { + return gin.H{"redeem_code": dto.RedeemCodeFromServiceAdmin(existing)}, nil + } + + return nil, infraerrors.Conflict("REDEEM_CODE_CONFLICT", "redeem code already used by another user") } // Delete handles deleting a redeem code diff --git a/backend/internal/handler/admin/search_truncate_test.go b/backend/internal/handler/admin/search_truncate_test.go new file mode 100644 index 00000000..ffd60e2a --- /dev/null +++ b/backend/internal/handler/admin/search_truncate_test.go @@ -0,0 +1,97 @@ +//go:build unit + +package admin + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// truncateSearchByRune 模拟 user_handler.go 中的 search 截断逻辑 +func truncateSearchByRune(search string, maxRunes int) string { + if runes := []rune(search); len(runes) > maxRunes { + return 
string(runes[:maxRunes]) + } + return search +} + +func TestTruncateSearchByRune(t *testing.T) { + tests := []struct { + name string + input string + maxRunes int + wantLen int // 期望的 rune 长度 + }{ + { + name: "纯中文超长", + input: string(make([]rune, 150)), + maxRunes: 100, + wantLen: 100, + }, + { + name: "纯 ASCII 超长", + input: string(make([]byte, 150)), + maxRunes: 100, + wantLen: 100, + }, + { + name: "空字符串", + input: "", + maxRunes: 100, + wantLen: 0, + }, + { + name: "恰好 100 个字符", + input: string(make([]rune, 100)), + maxRunes: 100, + wantLen: 100, + }, + { + name: "不足 100 字符不截断", + input: "hello世界", + maxRunes: 100, + wantLen: 7, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := truncateSearchByRune(tc.input, tc.maxRunes) + require.Equal(t, tc.wantLen, len([]rune(result))) + }) + } +} + +func TestTruncateSearchByRune_PreservesMultibyte(t *testing.T) { + // 101 个中文字符,截断到 100 个后应该仍然是有效 UTF-8 + input := "" + for i := 0; i < 101; i++ { + input += "中" + } + result := truncateSearchByRune(input, 100) + + require.Equal(t, 100, len([]rune(result))) + // 验证截断结果是有效的 UTF-8(每个中文字符 3 字节) + require.Equal(t, 300, len(result)) +} + +func TestTruncateSearchByRune_MixedASCIIAndMultibyte(t *testing.T) { + // 50 个 ASCII + 51 个中文 = 101 个 rune + input := "" + for i := 0; i < 50; i++ { + input += "a" + } + for i := 0; i < 51; i++ { + input += "中" + } + result := truncateSearchByRune(input, 100) + + runes := []rune(result) + require.Equal(t, 100, len(runes)) + // 前 50 个应该是 'a',后 50 个应该是 '中' + require.Equal(t, 'a', runes[0]) + require.Equal(t, 'a', runes[49]) + require.Equal(t, '中', runes[50]) + require.Equal(t, '中', runes[99]) +} diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 1e723ee5..04292088 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -1,7 +1,13 @@ package admin import ( + "crypto/rand" + 
"encoding/hex" + "encoding/json" + "fmt" "log" + "net/http" + "regexp" "strings" "time" @@ -14,21 +20,38 @@ import ( "github.com/gin-gonic/gin" ) +// semverPattern 预编译 semver 格式校验正则 +var semverPattern = regexp.MustCompile(`^\d+\.\d+\.\d+$`) + +// menuItemIDPattern validates custom menu item IDs: alphanumeric, hyphens, underscores only. +var menuItemIDPattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) + +// generateMenuItemID generates a short random hex ID for a custom menu item. +func generateMenuItemID() (string, error) { + b := make([]byte, 8) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("generate menu item ID: %w", err) + } + return hex.EncodeToString(b), nil +} + // SettingHandler 系统设置处理器 type SettingHandler struct { settingService *service.SettingService emailService *service.EmailService turnstileService *service.TurnstileService opsService *service.OpsService + soraS3Storage *service.SoraS3Storage } // NewSettingHandler 创建系统设置处理器 -func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService) *SettingHandler { +func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, soraS3Storage *service.SoraS3Storage) *SettingHandler { return &SettingHandler{ settingService: settingService, emailService: emailService, turnstileService: turnstileService, opsService: opsService, + soraS3Storage: soraS3Storage, } } @@ -43,10 +66,18 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { // Check if ops monitoring is enabled (respects config.ops.enabled) opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context()) + defaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(settings.DefaultSubscriptions)) + for _, sub := range settings.DefaultSubscriptions { + defaultSubscriptions = 
append(defaultSubscriptions, dto.DefaultSubscriptionSetting{ + GroupID: sub.GroupID, + ValidityDays: sub.ValidityDays, + }) + } response.Success(c, dto.SystemSettings{ RegistrationEnabled: settings.RegistrationEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled, + RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist, PromoCodeEnabled: settings.PromoCodeEnabled, PasswordResetEnabled: settings.PasswordResetEnabled, InvitationCodeEnabled: settings.InvitationCodeEnabled, @@ -76,8 +107,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { HideCcsImportButton: settings.HideCcsImportButton, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, + SoraClientEnabled: settings.SoraClientEnabled, + CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems), DefaultConcurrency: settings.DefaultConcurrency, DefaultBalance: settings.DefaultBalance, + DefaultSubscriptions: defaultSubscriptions, EnableModelFallback: settings.EnableModelFallback, FallbackModelAnthropic: settings.FallbackModelAnthropic, FallbackModelOpenAI: settings.FallbackModelOpenAI, @@ -89,18 +123,21 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled, OpsQueryModeDefault: settings.OpsQueryModeDefault, OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds, + MinClaudeCodeVersion: settings.MinClaudeCodeVersion, + AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling, }) } // UpdateSettingsRequest 更新设置请求 type UpdateSettingsRequest struct { // 注册设置 - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - PasswordResetEnabled bool `json:"password_reset_enabled"` - InvitationCodeEnabled bool `json:"invitation_code_enabled"` - TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 + RegistrationEnabled bool 
`json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` + TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 // 邮件服务设置 SMTPHost string `json:"smtp_host"` @@ -123,20 +160,23 @@ type UpdateSettingsRequest struct { LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` // OEM设置 - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo"` - SiteSubtitle string `json:"site_subtitle"` - APIBaseURL string `json:"api_base_url"` - ContactInfo string `json:"contact_info"` - DocURL string `json:"doc_url"` - HomeContent string `json:"home_content"` - HideCcsImportButton bool `json:"hide_ccs_import_button"` - PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"` - PurchaseSubscriptionURL *string `json:"purchase_subscription_url"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo"` + SiteSubtitle string `json:"site_subtitle"` + APIBaseURL string `json:"api_base_url"` + ContactInfo string `json:"contact_info"` + DocURL string `json:"doc_url"` + HomeContent string `json:"home_content"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"` + PurchaseSubscriptionURL *string `json:"purchase_subscription_url"` + SoraClientEnabled bool `json:"sora_client_enabled"` + CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"` // 默认配置 - DefaultConcurrency int `json:"default_concurrency"` - DefaultBalance float64 `json:"default_balance"` + DefaultConcurrency int `json:"default_concurrency"` + DefaultBalance float64 `json:"default_balance"` + DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"` // Model fallback 
configuration EnableModelFallback bool `json:"enable_model_fallback"` @@ -154,6 +194,11 @@ type UpdateSettingsRequest struct { OpsRealtimeMonitoringEnabled *bool `json:"ops_realtime_monitoring_enabled"` OpsQueryModeDefault *string `json:"ops_query_mode_default"` OpsMetricsIntervalSeconds *int `json:"ops_metrics_interval_seconds"` + + MinClaudeCodeVersion string `json:"min_claude_code_version"` + + // 分组隔离 + AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"` } // UpdateSettings 更新系统设置 @@ -181,6 +226,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { if req.SMTPPort <= 0 { req.SMTPPort = 587 } + req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions) // Turnstile 参数验证 if req.TurnstileEnabled { @@ -276,6 +322,84 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } } + // 自定义菜单项验证 + const ( + maxCustomMenuItems = 20 + maxMenuItemLabelLen = 50 + maxMenuItemURLLen = 2048 + maxMenuItemIconSVGLen = 10 * 1024 // 10KB + maxMenuItemIDLen = 32 + ) + + customMenuJSON := previousSettings.CustomMenuItems + if req.CustomMenuItems != nil { + items := *req.CustomMenuItems + if len(items) > maxCustomMenuItems { + response.BadRequest(c, "Too many custom menu items (max 20)") + return + } + for i, item := range items { + if strings.TrimSpace(item.Label) == "" { + response.BadRequest(c, "Custom menu item label is required") + return + } + if len(item.Label) > maxMenuItemLabelLen { + response.BadRequest(c, "Custom menu item label is too long (max 50 characters)") + return + } + if strings.TrimSpace(item.URL) == "" { + response.BadRequest(c, "Custom menu item URL is required") + return + } + if len(item.URL) > maxMenuItemURLLen { + response.BadRequest(c, "Custom menu item URL is too long (max 2048 characters)") + return + } + if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(item.URL)); err != nil { + response.BadRequest(c, "Custom menu item URL must be an absolute http(s) URL") + return + } + if 
item.Visibility != "user" && item.Visibility != "admin" { + response.BadRequest(c, "Custom menu item visibility must be 'user' or 'admin'") + return + } + if len(item.IconSVG) > maxMenuItemIconSVGLen { + response.BadRequest(c, "Custom menu item icon SVG is too large (max 10KB)") + return + } + // Auto-generate ID if missing + if strings.TrimSpace(item.ID) == "" { + id, err := generateMenuItemID() + if err != nil { + response.Error(c, http.StatusInternalServerError, "Failed to generate menu item ID") + return + } + items[i].ID = id + } else if len(item.ID) > maxMenuItemIDLen { + response.BadRequest(c, "Custom menu item ID is too long (max 32 characters)") + return + } else if !menuItemIDPattern.MatchString(item.ID) { + response.BadRequest(c, "Custom menu item ID contains invalid characters (only a-z, A-Z, 0-9, - and _ are allowed)") + return + } + } + // ID uniqueness check + seen := make(map[string]struct{}, len(items)) + for _, item := range items { + if _, exists := seen[item.ID]; exists { + response.BadRequest(c, "Duplicate custom menu item ID: "+item.ID) + return + } + seen[item.ID] = struct{}{} + } + menuBytes, err := json.Marshal(items) + if err != nil { + response.BadRequest(c, "Failed to serialize custom menu items") + return + } + customMenuJSON = string(menuBytes) + } + // Ops metrics collector interval validation (seconds). 
if req.OpsMetricsIntervalSeconds != nil { v := *req.OpsMetricsIntervalSeconds @@ -287,47 +411,68 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } req.OpsMetricsIntervalSeconds = &v } + defaultSubscriptions := make([]service.DefaultSubscriptionSetting, 0, len(req.DefaultSubscriptions)) + for _, sub := range req.DefaultSubscriptions { + defaultSubscriptions = append(defaultSubscriptions, service.DefaultSubscriptionSetting{ + GroupID: sub.GroupID, + ValidityDays: sub.ValidityDays, + }) + } + + // 验证最低版本号格式(空字符串=禁用,或合法 semver) + if req.MinClaudeCodeVersion != "" { + if !semverPattern.MatchString(req.MinClaudeCodeVersion) { + response.Error(c, http.StatusBadRequest, "min_claude_code_version must be empty or a valid semver (e.g. 2.1.63)") + return + } + } settings := &service.SystemSettings{ - RegistrationEnabled: req.RegistrationEnabled, - EmailVerifyEnabled: req.EmailVerifyEnabled, - PromoCodeEnabled: req.PromoCodeEnabled, - PasswordResetEnabled: req.PasswordResetEnabled, - InvitationCodeEnabled: req.InvitationCodeEnabled, - TotpEnabled: req.TotpEnabled, - SMTPHost: req.SMTPHost, - SMTPPort: req.SMTPPort, - SMTPUsername: req.SMTPUsername, - SMTPPassword: req.SMTPPassword, - SMTPFrom: req.SMTPFrom, - SMTPFromName: req.SMTPFromName, - SMTPUseTLS: req.SMTPUseTLS, - TurnstileEnabled: req.TurnstileEnabled, - TurnstileSiteKey: req.TurnstileSiteKey, - TurnstileSecretKey: req.TurnstileSecretKey, - LinuxDoConnectEnabled: req.LinuxDoConnectEnabled, - LinuxDoConnectClientID: req.LinuxDoConnectClientID, - LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret, - LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL, - SiteName: req.SiteName, - SiteLogo: req.SiteLogo, - SiteSubtitle: req.SiteSubtitle, - APIBaseURL: req.APIBaseURL, - ContactInfo: req.ContactInfo, - DocURL: req.DocURL, - HomeContent: req.HomeContent, - HideCcsImportButton: req.HideCcsImportButton, - PurchaseSubscriptionEnabled: purchaseEnabled, - PurchaseSubscriptionURL: purchaseURL, - 
DefaultConcurrency: req.DefaultConcurrency, - DefaultBalance: req.DefaultBalance, - EnableModelFallback: req.EnableModelFallback, - FallbackModelAnthropic: req.FallbackModelAnthropic, - FallbackModelOpenAI: req.FallbackModelOpenAI, - FallbackModelGemini: req.FallbackModelGemini, - FallbackModelAntigravity: req.FallbackModelAntigravity, - EnableIdentityPatch: req.EnableIdentityPatch, - IdentityPatchPrompt: req.IdentityPatchPrompt, + RegistrationEnabled: req.RegistrationEnabled, + EmailVerifyEnabled: req.EmailVerifyEnabled, + RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist, + PromoCodeEnabled: req.PromoCodeEnabled, + PasswordResetEnabled: req.PasswordResetEnabled, + InvitationCodeEnabled: req.InvitationCodeEnabled, + TotpEnabled: req.TotpEnabled, + SMTPHost: req.SMTPHost, + SMTPPort: req.SMTPPort, + SMTPUsername: req.SMTPUsername, + SMTPPassword: req.SMTPPassword, + SMTPFrom: req.SMTPFrom, + SMTPFromName: req.SMTPFromName, + SMTPUseTLS: req.SMTPUseTLS, + TurnstileEnabled: req.TurnstileEnabled, + TurnstileSiteKey: req.TurnstileSiteKey, + TurnstileSecretKey: req.TurnstileSecretKey, + LinuxDoConnectEnabled: req.LinuxDoConnectEnabled, + LinuxDoConnectClientID: req.LinuxDoConnectClientID, + LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret, + LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL, + SiteName: req.SiteName, + SiteLogo: req.SiteLogo, + SiteSubtitle: req.SiteSubtitle, + APIBaseURL: req.APIBaseURL, + ContactInfo: req.ContactInfo, + DocURL: req.DocURL, + HomeContent: req.HomeContent, + HideCcsImportButton: req.HideCcsImportButton, + PurchaseSubscriptionEnabled: purchaseEnabled, + PurchaseSubscriptionURL: purchaseURL, + SoraClientEnabled: req.SoraClientEnabled, + CustomMenuItems: customMenuJSON, + DefaultConcurrency: req.DefaultConcurrency, + DefaultBalance: req.DefaultBalance, + DefaultSubscriptions: defaultSubscriptions, + EnableModelFallback: req.EnableModelFallback, + FallbackModelAnthropic: req.FallbackModelAnthropic, + 
FallbackModelOpenAI: req.FallbackModelOpenAI, + FallbackModelGemini: req.FallbackModelGemini, + FallbackModelAntigravity: req.FallbackModelAntigravity, + EnableIdentityPatch: req.EnableIdentityPatch, + IdentityPatchPrompt: req.IdentityPatchPrompt, + MinClaudeCodeVersion: req.MinClaudeCodeVersion, + AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling, OpsMonitoringEnabled: func() bool { if req.OpsMonitoringEnabled != nil { return *req.OpsMonitoringEnabled @@ -367,10 +512,18 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.ErrorFrom(c, err) return } + updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions)) + for _, sub := range updatedSettings.DefaultSubscriptions { + updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{ + GroupID: sub.GroupID, + ValidityDays: sub.ValidityDays, + }) + } response.Success(c, dto.SystemSettings{ RegistrationEnabled: updatedSettings.RegistrationEnabled, EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, + RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist, PromoCodeEnabled: updatedSettings.PromoCodeEnabled, PasswordResetEnabled: updatedSettings.PasswordResetEnabled, InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled, @@ -400,8 +553,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { HideCcsImportButton: updatedSettings.HideCcsImportButton, PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL, + SoraClientEnabled: updatedSettings.SoraClientEnabled, + CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems), DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultBalance: updatedSettings.DefaultBalance, + DefaultSubscriptions: updatedDefaultSubscriptions, EnableModelFallback: updatedSettings.EnableModelFallback, FallbackModelAnthropic: 
updatedSettings.FallbackModelAnthropic, FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI, @@ -413,6 +569,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled, OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault, OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds, + MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion, + AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling, }) } @@ -444,6 +602,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.EmailVerifyEnabled != after.EmailVerifyEnabled { changed = append(changed, "email_verify_enabled") } + if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) { + changed = append(changed, "registration_email_suffix_whitelist") + } if before.PasswordResetEnabled != after.PasswordResetEnabled { changed = append(changed, "password_reset_enabled") } @@ -522,6 +683,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.DefaultBalance != after.DefaultBalance { changed = append(changed, "default_balance") } + if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) { + changed = append(changed, "default_subscriptions") + } if before.EnableModelFallback != after.EnableModelFallback { changed = append(changed, "enable_model_fallback") } @@ -555,9 +719,65 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.OpsMetricsIntervalSeconds != after.OpsMetricsIntervalSeconds { changed = append(changed, "ops_metrics_interval_seconds") } + if before.MinClaudeCodeVersion != after.MinClaudeCodeVersion { + changed = append(changed, "min_claude_code_version") + } + if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling { + changed = append(changed, "allow_ungrouped_key_scheduling") + } + if 
before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled { + changed = append(changed, "purchase_subscription_enabled") + } + if before.PurchaseSubscriptionURL != after.PurchaseSubscriptionURL { + changed = append(changed, "purchase_subscription_url") + } + if before.CustomMenuItems != after.CustomMenuItems { + changed = append(changed, "custom_menu_items") + } return changed } +func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto.DefaultSubscriptionSetting { + if len(input) == 0 { + return nil + } + normalized := make([]dto.DefaultSubscriptionSetting, 0, len(input)) + for _, item := range input { + if item.GroupID <= 0 || item.ValidityDays <= 0 { + continue + } + if item.ValidityDays > service.MaxValidityDays { + item.ValidityDays = service.MaxValidityDays + } + normalized = append(normalized, item) + } + return normalized +} + +func equalStringSlice(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i].GroupID != b[i].GroupID || a[i].ValidityDays != b[i].ValidityDays { + return false + } + } + return true +} + // TestSMTPRequest 测试SMTP连接请求 type TestSMTPRequest struct { SMTPHost string `json:"smtp_host" binding:"required"` @@ -750,6 +970,384 @@ func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) { }) } +func toSoraS3SettingsDTO(settings *service.SoraS3Settings) dto.SoraS3Settings { + if settings == nil { + return dto.SoraS3Settings{} + } + return dto.SoraS3Settings{ + Enabled: settings.Enabled, + Endpoint: settings.Endpoint, + Region: settings.Region, + Bucket: settings.Bucket, + AccessKeyID: settings.AccessKeyID, + SecretAccessKeyConfigured: settings.SecretAccessKeyConfigured, + Prefix: settings.Prefix, + ForcePathStyle: settings.ForcePathStyle, + CDNURL: 
settings.CDNURL, + DefaultStorageQuotaBytes: settings.DefaultStorageQuotaBytes, + } +} + +func toSoraS3ProfileDTO(profile service.SoraS3Profile) dto.SoraS3Profile { + return dto.SoraS3Profile{ + ProfileID: profile.ProfileID, + Name: profile.Name, + IsActive: profile.IsActive, + Enabled: profile.Enabled, + Endpoint: profile.Endpoint, + Region: profile.Region, + Bucket: profile.Bucket, + AccessKeyID: profile.AccessKeyID, + SecretAccessKeyConfigured: profile.SecretAccessKeyConfigured, + Prefix: profile.Prefix, + ForcePathStyle: profile.ForcePathStyle, + CDNURL: profile.CDNURL, + DefaultStorageQuotaBytes: profile.DefaultStorageQuotaBytes, + UpdatedAt: profile.UpdatedAt, + } +} + +func validateSoraS3RequiredWhenEnabled(enabled bool, endpoint, bucket, accessKeyID, secretAccessKey string, hasStoredSecret bool) error { + if !enabled { + return nil + } + if strings.TrimSpace(endpoint) == "" { + return fmt.Errorf("S3 Endpoint is required when enabled") + } + if strings.TrimSpace(bucket) == "" { + return fmt.Errorf("S3 Bucket is required when enabled") + } + if strings.TrimSpace(accessKeyID) == "" { + return fmt.Errorf("S3 Access Key ID is required when enabled") + } + if strings.TrimSpace(secretAccessKey) != "" || hasStoredSecret { + return nil + } + return fmt.Errorf("S3 Secret Access Key is required when enabled") +} + +func findSoraS3ProfileByID(items []service.SoraS3Profile, profileID string) *service.SoraS3Profile { + for idx := range items { + if items[idx].ProfileID == profileID { + return &items[idx] + } + } + return nil +} + +// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置接口) +// GET /api/v1/admin/settings/sora-s3 +func (h *SettingHandler) GetSoraS3Settings(c *gin.Context) { + settings, err := h.settingService.GetSoraS3Settings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, toSoraS3SettingsDTO(settings)) +} + +// ListSoraS3Profiles 获取 Sora S3 多配置 +// GET /api/v1/admin/settings/sora-s3/profiles +func (h 
*SettingHandler) ListSoraS3Profiles(c *gin.Context) { + result, err := h.settingService.ListSoraS3Profiles(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + items := make([]dto.SoraS3Profile, 0, len(result.Items)) + for idx := range result.Items { + items = append(items, toSoraS3ProfileDTO(result.Items[idx])) + } + response.Success(c, dto.ListSoraS3ProfilesResponse{ + ActiveProfileID: result.ActiveProfileID, + Items: items, + }) +} + +// UpdateSoraS3SettingsRequest 更新/测试 Sora S3 配置请求(兼容旧接口) +type UpdateSoraS3SettingsRequest struct { + ProfileID string `json:"profile_id"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` +} + +type CreateSoraS3ProfileRequest struct { + ProfileID string `json:"profile_id"` + Name string `json:"name"` + SetActive bool `json:"set_active"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` +} + +type UpdateSoraS3ProfileRequest struct { + Name string `json:"name"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` 
+ DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` +} + +// CreateSoraS3Profile 创建 Sora S3 配置 +// POST /api/v1/admin/settings/sora-s3/profiles +func (h *SettingHandler) CreateSoraS3Profile(c *gin.Context) { + var req CreateSoraS3ProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if req.DefaultStorageQuotaBytes < 0 { + req.DefaultStorageQuotaBytes = 0 + } + if strings.TrimSpace(req.Name) == "" { + response.BadRequest(c, "Name is required") + return + } + if strings.TrimSpace(req.ProfileID) == "" { + response.BadRequest(c, "Profile ID is required") + return + } + if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, false); err != nil { + response.BadRequest(c, err.Error()) + return + } + + created, err := h.settingService.CreateSoraS3Profile(c.Request.Context(), &service.SoraS3Profile{ + ProfileID: req.ProfileID, + Name: req.Name, + Enabled: req.Enabled, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + CDNURL: req.CDNURL, + DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes, + }, req.SetActive) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, toSoraS3ProfileDTO(*created)) +} + +// UpdateSoraS3Profile 更新 Sora S3 配置 +// PUT /api/v1/admin/settings/sora-s3/profiles/:profile_id +func (h *SettingHandler) UpdateSoraS3Profile(c *gin.Context) { + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Profile ID is required") + return + } + + var req UpdateSoraS3ProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if req.DefaultStorageQuotaBytes < 0 { + req.DefaultStorageQuotaBytes = 
0 + } + if strings.TrimSpace(req.Name) == "" { + response.BadRequest(c, "Name is required") + return + } + + existingList, err := h.settingService.ListSoraS3Profiles(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + existing := findSoraS3ProfileByID(existingList.Items, profileID) + if existing == nil { + response.ErrorFrom(c, service.ErrSoraS3ProfileNotFound) + return + } + if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil { + response.BadRequest(c, err.Error()) + return + } + + updated, updateErr := h.settingService.UpdateSoraS3Profile(c.Request.Context(), profileID, &service.SoraS3Profile{ + Name: req.Name, + Enabled: req.Enabled, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + CDNURL: req.CDNURL, + DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes, + }) + if updateErr != nil { + response.ErrorFrom(c, updateErr) + return + } + + response.Success(c, toSoraS3ProfileDTO(*updated)) +} + +// DeleteSoraS3Profile 删除 Sora S3 配置 +// DELETE /api/v1/admin/settings/sora-s3/profiles/:profile_id +func (h *SettingHandler) DeleteSoraS3Profile(c *gin.Context) { + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Profile ID is required") + return + } + if err := h.settingService.DeleteSoraS3Profile(c.Request.Context(), profileID); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"deleted": true}) +} + +// SetActiveSoraS3Profile 切换激活 Sora S3 配置 +// POST /api/v1/admin/settings/sora-s3/profiles/:profile_id/activate +func (h *SettingHandler) SetActiveSoraS3Profile(c *gin.Context) { + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Profile 
ID is required") + return + } + active, err := h.settingService.SetActiveSoraS3Profile(c.Request.Context(), profileID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, toSoraS3ProfileDTO(*active)) +} + +// UpdateSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置接口) +// PUT /api/v1/admin/settings/sora-s3 +func (h *SettingHandler) UpdateSoraS3Settings(c *gin.Context) { + var req UpdateSoraS3SettingsRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + existing, err := h.settingService.GetSoraS3Settings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + if req.DefaultStorageQuotaBytes < 0 { + req.DefaultStorageQuotaBytes = 0 + } + if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil { + response.BadRequest(c, err.Error()) + return + } + + settings := &service.SoraS3Settings{ + Enabled: req.Enabled, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + CDNURL: req.CDNURL, + DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes, + } + if err := h.settingService.SetSoraS3Settings(c.Request.Context(), settings); err != nil { + response.ErrorFrom(c, err) + return + } + + updatedSettings, err := h.settingService.GetSoraS3Settings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, toSoraS3SettingsDTO(updatedSettings)) +} + +// TestSoraS3Connection 测试 Sora S3 连接(HeadBucket) +// POST /api/v1/admin/settings/sora-s3/test +func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) { + if h.soraS3Storage == nil { + response.Error(c, 500, "S3 存储服务未初始化") + return + } + + var req UpdateSoraS3SettingsRequest + if err := 
c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + if !req.Enabled { + response.BadRequest(c, "S3 未启用,无法测试连接") + return + } + + if req.SecretAccessKey == "" { + if req.ProfileID != "" { + profiles, err := h.settingService.ListSoraS3Profiles(c.Request.Context()) + if err == nil { + profile := findSoraS3ProfileByID(profiles.Items, req.ProfileID) + if profile != nil { + req.SecretAccessKey = profile.SecretAccessKey + } + } + } + if req.SecretAccessKey == "" { + existing, err := h.settingService.GetSoraS3Settings(c.Request.Context()) + if err == nil { + req.SecretAccessKey = existing.SecretAccessKey + } + } + } + + testCfg := &service.SoraS3Settings{ + Enabled: true, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + CDNURL: req.CDNURL, + } + if err := h.soraS3Storage.TestConnectionWithSettings(c.Request.Context(), testCfg); err != nil { + response.Error(c, 400, "S3 连接测试失败: "+err.Error()) + return + } + response.Success(c, gin.H{"message": "S3 连接成功"}) +} + // UpdateStreamTimeoutSettingsRequest 更新流超时配置请求 type UpdateStreamTimeoutSettingsRequest struct { Enabled bool `json:"enabled"` diff --git a/backend/internal/handler/admin/snapshot_cache.go b/backend/internal/handler/admin/snapshot_cache.go new file mode 100644 index 00000000..809760a7 --- /dev/null +++ b/backend/internal/handler/admin/snapshot_cache.go @@ -0,0 +1,95 @@ +package admin + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "strings" + "sync" + "time" +) + +type snapshotCacheEntry struct { + ETag string + Payload any + ExpiresAt time.Time +} + +type snapshotCache struct { + mu sync.RWMutex + ttl time.Duration + items map[string]snapshotCacheEntry +} + +func newSnapshotCache(ttl time.Duration) *snapshotCache { + if ttl <= 0 { + ttl = 30 * time.Second + } + return &snapshotCache{ + 
ttl: ttl, + items: make(map[string]snapshotCacheEntry), + } +} + +func (c *snapshotCache) Get(key string) (snapshotCacheEntry, bool) { + if c == nil || key == "" { + return snapshotCacheEntry{}, false + } + now := time.Now() + + c.mu.RLock() + entry, ok := c.items[key] + c.mu.RUnlock() + if !ok { + return snapshotCacheEntry{}, false + } + if now.After(entry.ExpiresAt) { + c.mu.Lock() + delete(c.items, key) + c.mu.Unlock() + return snapshotCacheEntry{}, false + } + return entry, true +} + +func (c *snapshotCache) Set(key string, payload any) snapshotCacheEntry { + if c == nil { + return snapshotCacheEntry{} + } + entry := snapshotCacheEntry{ + ETag: buildETagFromAny(payload), + Payload: payload, + ExpiresAt: time.Now().Add(c.ttl), + } + if key == "" { + return entry + } + c.mu.Lock() + c.items[key] = entry + c.mu.Unlock() + return entry +} + +func buildETagFromAny(payload any) string { + raw, err := json.Marshal(payload) + if err != nil { + return "" + } + sum := sha256.Sum256(raw) + return "\"" + hex.EncodeToString(sum[:]) + "\"" +} + +func parseBoolQueryWithDefault(raw string, def bool) bool { + value := strings.TrimSpace(strings.ToLower(raw)) + if value == "" { + return def + } + switch value { + case "1", "true", "yes", "on": + return true + case "0", "false", "no", "off": + return false + default: + return def + } +} diff --git a/backend/internal/handler/admin/snapshot_cache_test.go b/backend/internal/handler/admin/snapshot_cache_test.go new file mode 100644 index 00000000..f1c1453e --- /dev/null +++ b/backend/internal/handler/admin/snapshot_cache_test.go @@ -0,0 +1,128 @@ +//go:build unit + +package admin + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestSnapshotCache_SetAndGet(t *testing.T) { + c := newSnapshotCache(5 * time.Second) + + entry := c.Set("key1", map[string]string{"hello": "world"}) + require.NotEmpty(t, entry.ETag) + require.NotNil(t, entry.Payload) + + got, ok := c.Get("key1") + require.True(t, ok) + 
require.Equal(t, entry.ETag, got.ETag) +} + +func TestSnapshotCache_Expiration(t *testing.T) { + c := newSnapshotCache(1 * time.Millisecond) + + c.Set("key1", "value") + time.Sleep(5 * time.Millisecond) + + _, ok := c.Get("key1") + require.False(t, ok, "expired entry should not be returned") +} + +func TestSnapshotCache_GetEmptyKey(t *testing.T) { + c := newSnapshotCache(5 * time.Second) + _, ok := c.Get("") + require.False(t, ok) +} + +func TestSnapshotCache_GetMiss(t *testing.T) { + c := newSnapshotCache(5 * time.Second) + _, ok := c.Get("nonexistent") + require.False(t, ok) +} + +func TestSnapshotCache_NilReceiver(t *testing.T) { + var c *snapshotCache + _, ok := c.Get("key") + require.False(t, ok) + + entry := c.Set("key", "value") + require.Empty(t, entry.ETag) +} + +func TestSnapshotCache_SetEmptyKey(t *testing.T) { + c := newSnapshotCache(5 * time.Second) + + // Set with empty key should return entry but not store it + entry := c.Set("", "value") + require.NotEmpty(t, entry.ETag) + + _, ok := c.Get("") + require.False(t, ok) +} + +func TestSnapshotCache_DefaultTTL(t *testing.T) { + c := newSnapshotCache(0) + require.Equal(t, 30*time.Second, c.ttl) + + c2 := newSnapshotCache(-1 * time.Second) + require.Equal(t, 30*time.Second, c2.ttl) +} + +func TestSnapshotCache_ETagDeterministic(t *testing.T) { + c := newSnapshotCache(5 * time.Second) + payload := map[string]int{"a": 1, "b": 2} + + entry1 := c.Set("k1", payload) + entry2 := c.Set("k2", payload) + require.Equal(t, entry1.ETag, entry2.ETag, "same payload should produce same ETag") +} + +func TestSnapshotCache_ETagFormat(t *testing.T) { + c := newSnapshotCache(5 * time.Second) + entry := c.Set("k", "test") + // ETag should be quoted hex string: "abcdef..." 
+ require.True(t, len(entry.ETag) > 2) + require.Equal(t, byte('"'), entry.ETag[0]) + require.Equal(t, byte('"'), entry.ETag[len(entry.ETag)-1]) +} + +func TestBuildETagFromAny_UnmarshalablePayload(t *testing.T) { + // channels are not JSON-serializable + etag := buildETagFromAny(make(chan int)) + require.Empty(t, etag) +} + +func TestParseBoolQueryWithDefault(t *testing.T) { + tests := []struct { + name string + raw string + def bool + want bool + }{ + {"empty returns default true", "", true, true}, + {"empty returns default false", "", false, false}, + {"1", "1", false, true}, + {"true", "true", false, true}, + {"TRUE", "TRUE", false, true}, + {"yes", "yes", false, true}, + {"on", "on", false, true}, + {"0", "0", true, false}, + {"false", "false", true, false}, + {"FALSE", "FALSE", true, false}, + {"no", "no", true, false}, + {"off", "off", true, false}, + {"whitespace trimmed", " true ", false, true}, + {"unknown returns default true", "maybe", true, true}, + {"unknown returns default false", "maybe", false, false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := parseBoolQueryWithDefault(tc.raw, tc.def) + require.Equal(t, tc.want, got) + }) + } +} diff --git a/backend/internal/handler/admin/subscription_handler.go b/backend/internal/handler/admin/subscription_handler.go index 51995ab1..e5b6db13 100644 --- a/backend/internal/handler/admin/subscription_handler.go +++ b/backend/internal/handler/admin/subscription_handler.go @@ -1,6 +1,7 @@ package admin import ( + "context" "strconv" "github.com/Wei-Shaw/sub2api/internal/handler/dto" @@ -199,13 +200,20 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) { return } - subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days) - if err != nil { - response.ErrorFrom(c, err) - return + idempotencyPayload := struct { + SubscriptionID int64 `json:"subscription_id"` + Body AdjustSubscriptionRequest `json:"body"` + }{ + SubscriptionID: 
subscriptionID, + Body: req, } - - response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription)) + executeAdminIdempotentJSON(c, "admin.subscriptions.extend", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + subscription, execErr := h.subscriptionService.ExtendSubscription(ctx, subscriptionID, req.Days) + if execErr != nil { + return nil, execErr + } + return dto.UserSubscriptionFromServiceAdmin(subscription), nil + }) } // Revoke handles revoking a subscription diff --git a/backend/internal/handler/admin/system_handler.go b/backend/internal/handler/admin/system_handler.go index 28c075aa..3e2022c7 100644 --- a/backend/internal/handler/admin/system_handler.go +++ b/backend/internal/handler/admin/system_handler.go @@ -1,11 +1,15 @@ package admin import ( + "context" "net/http" + "strconv" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/sysutil" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -14,12 +18,14 @@ import ( // SystemHandler handles system-related operations type SystemHandler struct { updateSvc *service.UpdateService + lockSvc *service.SystemOperationLockService } // NewSystemHandler creates a new SystemHandler -func NewSystemHandler(updateSvc *service.UpdateService) *SystemHandler { +func NewSystemHandler(updateSvc *service.UpdateService, lockSvc *service.SystemOperationLockService) *SystemHandler { return &SystemHandler{ updateSvc: updateSvc, + lockSvc: lockSvc, } } @@ -47,41 +53,125 @@ func (h *SystemHandler) CheckUpdates(c *gin.Context) { // PerformUpdate downloads and applies the update // POST /api/v1/admin/system/update func (h *SystemHandler) PerformUpdate(c *gin.Context) { - if err := h.updateSvc.PerformUpdate(c.Request.Context()); err != nil { - response.Error(c, http.StatusInternalServerError, err.Error()) - return - } - 
response.Success(c, gin.H{ - "message": "Update completed. Please restart the service.", - "need_restart": true, + operationID := buildSystemOperationID(c, "update") + payload := gin.H{"operation_id": operationID} + executeAdminIdempotentJSON(c, "admin.system.update", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) { + lock, release, err := h.acquireSystemLock(ctx, operationID) + if err != nil { + return nil, err + } + var releaseReason string + succeeded := false + defer func() { + release(releaseReason, succeeded) + }() + + if err := h.updateSvc.PerformUpdate(ctx); err != nil { + releaseReason = "SYSTEM_UPDATE_FAILED" + return nil, err + } + succeeded = true + + return gin.H{ + "message": "Update completed. Please restart the service.", + "need_restart": true, + "operation_id": lock.OperationID(), + }, nil }) } // Rollback restores the previous version // POST /api/v1/admin/system/rollback func (h *SystemHandler) Rollback(c *gin.Context) { - if err := h.updateSvc.Rollback(); err != nil { - response.Error(c, http.StatusInternalServerError, err.Error()) - return - } - response.Success(c, gin.H{ - "message": "Rollback completed. Please restart the service.", - "need_restart": true, + operationID := buildSystemOperationID(c, "rollback") + payload := gin.H{"operation_id": operationID} + executeAdminIdempotentJSON(c, "admin.system.rollback", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) { + lock, release, err := h.acquireSystemLock(ctx, operationID) + if err != nil { + return nil, err + } + var releaseReason string + succeeded := false + defer func() { + release(releaseReason, succeeded) + }() + + if err := h.updateSvc.Rollback(); err != nil { + releaseReason = "SYSTEM_ROLLBACK_FAILED" + return nil, err + } + succeeded = true + + return gin.H{ + "message": "Rollback completed. 
Please restart the service.", + "need_restart": true, + "operation_id": lock.OperationID(), + }, nil }) } // RestartService restarts the systemd service // POST /api/v1/admin/system/restart func (h *SystemHandler) RestartService(c *gin.Context) { - // Schedule service restart in background after sending response - // This ensures the client receives the success response before the service restarts - go func() { - // Wait a moment to ensure the response is sent - time.Sleep(500 * time.Millisecond) - sysutil.RestartServiceAsync() - }() + operationID := buildSystemOperationID(c, "restart") + payload := gin.H{"operation_id": operationID} + executeAdminIdempotentJSON(c, "admin.system.restart", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) { + lock, release, err := h.acquireSystemLock(ctx, operationID) + if err != nil { + return nil, err + } + succeeded := false + defer func() { + release("", succeeded) + }() - response.Success(c, gin.H{ - "message": "Service restart initiated", + // Schedule service restart in background after sending response + // This ensures the client receives the success response before the service restarts + go func() { + // Wait a moment to ensure the response is sent + time.Sleep(500 * time.Millisecond) + sysutil.RestartServiceAsync() + }() + succeeded = true + return gin.H{ + "message": "Service restart initiated", + "operation_id": lock.OperationID(), + }, nil }) } + +func (h *SystemHandler) acquireSystemLock( + ctx context.Context, + operationID string, +) (*service.SystemOperationLock, func(string, bool), error) { + if h.lockSvc == nil { + return nil, nil, service.ErrIdempotencyStoreUnavail + } + lock, err := h.lockSvc.Acquire(ctx, operationID) + if err != nil { + return nil, nil, err + } + release := func(reason string, succeeded bool) { + releaseCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _ = h.lockSvc.Release(releaseCtx, lock, succeeded, 
reason) + } + return lock, release, nil +} + +func buildSystemOperationID(c *gin.Context, operation string) string { + key := strings.TrimSpace(c.GetHeader("Idempotency-Key")) + if key == "" { + return "sysop-" + operation + "-" + strconv.FormatInt(time.Now().UnixNano(), 36) + } + actorScope := "admin:0" + if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok { + actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10) + } + seed := operation + "|" + actorScope + "|" + c.FullPath() + "|" + key + hash := service.HashIdempotencyKey(seed) + if len(hash) > 24 { + hash = hash[:24] + } + return "sysop-" + hash +} diff --git a/backend/internal/handler/admin/usage_cleanup_handler_test.go b/backend/internal/handler/admin/usage_cleanup_handler_test.go index ed1c7cc2..6152d5e9 100644 --- a/backend/internal/handler/admin/usage_cleanup_handler_test.go +++ b/backend/internal/handler/admin/usage_cleanup_handler_test.go @@ -225,6 +225,92 @@ func TestUsageHandlerCreateCleanupTaskInvalidEndDate(t *testing.T) { require.Equal(t, http.StatusBadRequest, recorder.Code) } +func TestUsageHandlerCreateCleanupTaskInvalidRequestType(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 88) + + payload := map[string]any{ + "start_date": "2024-01-01", + "end_date": "2024-01-02", + "timezone": "UTC", + "request_type": "invalid", + } + body, err := json.Marshal(payload) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusBadRequest, recorder.Code) +} + +func TestUsageHandlerCreateCleanupTaskRequestTypePriority(t *testing.T) { + repo := 
&cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 99) + + payload := map[string]any{ + "start_date": "2024-01-01", + "end_date": "2024-01-02", + "timezone": "UTC", + "request_type": "ws_v2", + "stream": false, + } + body, err := json.Marshal(payload) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusOK, recorder.Code) + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Len(t, repo.created, 1) + created := repo.created[0] + require.NotNil(t, created.Filters.RequestType) + require.Equal(t, int16(service.RequestTypeWSV2), *created.Filters.RequestType) + require.Nil(t, created.Filters.Stream) +} + +func TestUsageHandlerCreateCleanupTaskWithLegacyStream(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 99) + + payload := map[string]any{ + "start_date": "2024-01-01", + "end_date": "2024-01-02", + "timezone": "UTC", + "stream": true, + } + body, err := json.Marshal(payload) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusOK, recorder.Code) + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Len(t, repo.created, 1) + created := repo.created[0] + require.Nil(t, created.Filters.RequestType) + 
require.NotNil(t, created.Filters.Stream) + require.True(t, *created.Filters.Stream) +} + func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) { repo := &cleanupRepoStub{} cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} diff --git a/backend/internal/handler/admin/usage_handler.go b/backend/internal/handler/admin/usage_handler.go index 3f3238dd..05fd00f1 100644 --- a/backend/internal/handler/admin/usage_handler.go +++ b/backend/internal/handler/admin/usage_handler.go @@ -1,13 +1,14 @@ package admin import ( - "log" + "context" "net/http" "strconv" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/handler/dto" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" @@ -50,6 +51,7 @@ type CreateUsageCleanupTaskRequest struct { AccountID *int64 `json:"account_id"` GroupID *int64 `json:"group_id"` Model *string `json:"model"` + RequestType *string `json:"request_type"` Stream *bool `json:"stream"` BillingType *int8 `json:"billing_type"` Timezone string `json:"timezone"` @@ -59,6 +61,15 @@ type CreateUsageCleanupTaskRequest struct { // GET /api/v1/admin/usage func (h *UsageHandler) List(c *gin.Context) { page, pageSize := response.ParsePagination(c) + exactTotal := false + if exactTotalRaw := strings.TrimSpace(c.Query("exact_total")); exactTotalRaw != "" { + parsed, err := strconv.ParseBool(exactTotalRaw) + if err != nil { + response.BadRequest(c, "Invalid exact_total value, use true or false") + return + } + exactTotal = parsed + } // Parse filters var userID, apiKeyID, accountID, groupID int64 @@ -100,8 +111,17 @@ func (h *UsageHandler) List(c *gin.Context) { model := c.Query("model") + var requestType *int16 var stream *bool - if streamStr := c.Query("stream"); streamStr != "" { + if requestTypeStr := strings.TrimSpace(c.Query("request_type")); 
requestTypeStr != "" { + parsed, err := service.ParseUsageRequestType(requestTypeStr) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + } else if streamStr := c.Query("stream"); streamStr != "" { val, err := strconv.ParseBool(streamStr) if err != nil { response.BadRequest(c, "Invalid stream value, use true or false") @@ -151,10 +171,12 @@ func (h *UsageHandler) List(c *gin.Context) { AccountID: accountID, GroupID: groupID, Model: model, + RequestType: requestType, Stream: stream, BillingType: billingType, StartTime: startTime, EndTime: endTime, + ExactTotal: exactTotal, } records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters) @@ -213,8 +235,17 @@ func (h *UsageHandler) Stats(c *gin.Context) { model := c.Query("model") + var requestType *int16 var stream *bool - if streamStr := c.Query("stream"); streamStr != "" { + if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { + parsed, err := service.ParseUsageRequestType(requestTypeStr) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + } else if streamStr := c.Query("stream"); streamStr != "" { val, err := strconv.ParseBool(streamStr) if err != nil { response.BadRequest(c, "Invalid stream value, use true or false") @@ -277,6 +308,7 @@ func (h *UsageHandler) Stats(c *gin.Context) { AccountID: accountID, GroupID: groupID, Model: model, + RequestType: requestType, Stream: stream, BillingType: billingType, StartTime: &startTime, @@ -378,11 +410,11 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) { operator = subject.UserID } page, pageSize := response.ParsePagination(c) - log.Printf("[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize) + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize) params := 
pagination.PaginationParams{Page: page, PageSize: pageSize} tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params) if err != nil { - log.Printf("[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err) + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err) response.ErrorFrom(c, err) return } @@ -390,7 +422,7 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) { for i := range tasks { out = append(out, *dto.UsageCleanupTaskFromService(&tasks[i])) } - log.Printf("[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize) + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize) response.Paginated(c, out, result.Total, page, pageSize) } @@ -431,6 +463,19 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) { } endTime = endTime.Add(24*time.Hour - time.Nanosecond) + var requestType *int16 + stream := req.Stream + if req.RequestType != nil { + parsed, err := service.ParseUsageRequestType(*req.RequestType) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + stream = nil + } + filters := service.UsageCleanupFilters{ StartTime: startTime, EndTime: endTime, @@ -439,7 +484,8 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) { AccountID: req.AccountID, GroupID: req.GroupID, Model: req.Model, - Stream: req.Stream, + RequestType: requestType, + Stream: stream, BillingType: req.BillingType, } @@ -463,38 +509,50 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) { if filters.Model != nil { model = *filters.Model } - var stream any + var streamValue any if filters.Stream != nil { - stream = *filters.Stream + streamValue = *filters.Stream + } + var 
requestTypeName any + if filters.RequestType != nil { + requestTypeName = service.RequestTypeFromInt16(*filters.RequestType).String() } var billingType any if filters.BillingType != nil { billingType = *filters.BillingType } - log.Printf("[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q", - subject.UserID, - filters.StartTime.Format(time.RFC3339), - filters.EndTime.Format(time.RFC3339), - userID, - apiKeyID, - accountID, - groupID, - model, - stream, - billingType, - req.Timezone, - ) - - task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID) - if err != nil { - log.Printf("[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err) - response.ErrorFrom(c, err) - return + idempotencyPayload := struct { + OperatorID int64 `json:"operator_id"` + Body CreateUsageCleanupTaskRequest `json:"body"` + }{ + OperatorID: subject.UserID, + Body: req, } + executeAdminIdempotentJSON(c, "admin.usage.cleanup_tasks.create", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v request_type=%v stream=%v billing_type=%v tz=%q", + subject.UserID, + filters.StartTime.Format(time.RFC3339), + filters.EndTime.Format(time.RFC3339), + userID, + apiKeyID, + accountID, + groupID, + model, + requestTypeName, + streamValue, + billingType, + req.Timezone, + ) - log.Printf("[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status) - response.Success(c, dto.UsageCleanupTaskFromService(task)) + task, err := h.cleanupService.CreateTask(ctx, filters, subject.UserID) + if err != nil { + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err) + return nil, err + } + 
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status) + return dto.UsageCleanupTaskFromService(task), nil + }) } // CancelCleanupTask handles canceling a usage cleanup task @@ -515,12 +573,12 @@ func (h *UsageHandler) CancelCleanupTask(c *gin.Context) { response.BadRequest(c, "Invalid task id") return } - log.Printf("[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID) + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID) if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil { - log.Printf("[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err) + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err) response.ErrorFrom(c, err) return } - log.Printf("[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID) + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID) response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled}) } diff --git a/backend/internal/handler/admin/usage_handler_request_type_test.go b/backend/internal/handler/admin/usage_handler_request_type_test.go new file mode 100644 index 00000000..3f158316 --- /dev/null +++ b/backend/internal/handler/admin/usage_handler_request_type_test.go @@ -0,0 +1,140 @@ +package admin + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type adminUsageRepoCapture struct { + service.UsageLogRepository + listFilters usagestats.UsageLogFilters + statsFilters 
usagestats.UsageLogFilters +} + +func (s *adminUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) { + s.listFilters = filters + return []service.UsageLog{}, &pagination.PaginationResult{ + Total: 0, + Page: params.Page, + PageSize: params.PageSize, + Pages: 0, + }, nil +} + +func (s *adminUsageRepoCapture) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) { + s.statsFilters = filters + return &usagestats.UsageStats{}, nil +} + +func newAdminUsageRequestTypeTestRouter(repo *adminUsageRepoCapture) *gin.Engine { + gin.SetMode(gin.TestMode) + usageSvc := service.NewUsageService(repo, nil, nil, nil) + handler := NewUsageHandler(usageSvc, nil, nil, nil) + router := gin.New() + router.GET("/admin/usage", handler.List) + router.GET("/admin/usage/stats", handler.Stats) + return router +} + +func TestAdminUsageListRequestTypePriority(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage?request_type=ws_v2&stream=false", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, repo.listFilters.RequestType) + require.Equal(t, int16(service.RequestTypeWSV2), *repo.listFilters.RequestType) + require.Nil(t, repo.listFilters.Stream) +} + +func TestAdminUsageListInvalidRequestType(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage?request_type=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestAdminUsageListInvalidStream(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := 
newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage?stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestAdminUsageListExactTotalTrue(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=true", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.True(t, repo.listFilters.ExactTotal) +} + +func TestAdminUsageListInvalidExactTotal(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage?exact_total=oops", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestAdminUsageStatsRequestTypePriority(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?request_type=stream&stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, repo.statsFilters.RequestType) + require.Equal(t, int16(service.RequestTypeStream), *repo.statsFilters.RequestType) + require.Nil(t, repo.statsFilters.Stream) +} + +func TestAdminUsageStatsInvalidRequestType(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?request_type=oops", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestAdminUsageStatsInvalidStream(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := 
newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?stream=oops", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} diff --git a/backend/internal/handler/admin/user_attribute_handler.go b/backend/internal/handler/admin/user_attribute_handler.go index 2f326279..3f84076e 100644 --- a/backend/internal/handler/admin/user_attribute_handler.go +++ b/backend/internal/handler/admin/user_attribute_handler.go @@ -1,7 +1,9 @@ package admin import ( + "encoding/json" "strconv" + "time" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" @@ -67,6 +69,8 @@ type BatchUserAttributesResponse struct { Attributes map[int64]map[int64]string `json:"attributes"` } +var userAttributesBatchCache = newSnapshotCache(30 * time.Second) + // AttributeDefinitionResponse represents attribute definition response type AttributeDefinitionResponse struct { ID int64 `json:"id"` @@ -327,16 +331,32 @@ func (h *UserAttributeHandler) GetBatchUserAttributes(c *gin.Context) { return } - if len(req.UserIDs) == 0 { + userIDs := normalizeInt64IDList(req.UserIDs) + if len(userIDs) == 0 { response.Success(c, BatchUserAttributesResponse{Attributes: map[int64]map[int64]string{}}) return } - attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), req.UserIDs) + keyRaw, _ := json.Marshal(struct { + UserIDs []int64 `json:"user_ids"` + }{ + UserIDs: userIDs, + }) + cacheKey := string(keyRaw) + if cached, ok := userAttributesBatchCache.Get(cacheKey); ok { + c.Header("X-Snapshot-Cache", "hit") + response.Success(c, cached.Payload) + return + } + + attrs, err := h.attrService.GetBatchUserAttributes(c.Request.Context(), userIDs) if err != nil { response.ErrorFrom(c, err) return } - response.Success(c, BatchUserAttributesResponse{Attributes: attrs}) + payload := BatchUserAttributesResponse{Attributes: attrs} + 
userAttributesBatchCache.Set(cacheKey, payload) + c.Header("X-Snapshot-Cache", "miss") + response.Success(c, payload) } diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index efb9abb5..5a55ab14 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -1,6 +1,7 @@ package admin import ( + "context" "strconv" "strings" @@ -33,13 +34,14 @@ func NewUserHandler(adminService service.AdminService, concurrencyService *servi // CreateUserRequest represents admin create user request type CreateUserRequest struct { - Email string `json:"email" binding:"required,email"` - Password string `json:"password" binding:"required,min=6"` - Username string `json:"username"` - Notes string `json:"notes"` - Balance float64 `json:"balance"` - Concurrency int `json:"concurrency"` - AllowedGroups []int64 `json:"allowed_groups"` + Email string `json:"email" binding:"required,email"` + Password string `json:"password" binding:"required,min=6"` + Username string `json:"username"` + Notes string `json:"notes"` + Balance float64 `json:"balance"` + Concurrency int `json:"concurrency"` + AllowedGroups []int64 `json:"allowed_groups"` + SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` } // UpdateUserRequest represents admin update user request @@ -55,7 +57,8 @@ type UpdateUserRequest struct { AllowedGroups *[]int64 `json:"allowed_groups"` // GroupRates 用户专属分组倍率配置 // map[groupID]*rate,nil 表示删除该分组的专属倍率 - GroupRates map[int64]*float64 `json:"group_rates"` + GroupRates map[int64]*float64 `json:"group_rates"` + SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"` } // UpdateBalanceRequest represents balance update request @@ -78,8 +81,8 @@ func (h *UserHandler) List(c *gin.Context) { search := c.Query("search") // 标准化和验证 search 参数 search = strings.TrimSpace(search) - if len(search) > 100 { - search = search[:100] + if runes := []rune(search); len(runes) > 100 { + 
search = string(runes[:100]) } filters := service.UserListFilters{ @@ -88,6 +91,10 @@ func (h *UserHandler) List(c *gin.Context) { Search: search, Attributes: parseAttributeFilters(c), } + if raw, ok := c.GetQuery("include_subscriptions"); ok { + includeSubscriptions := parseBoolQueryWithDefault(raw, true) + filters.IncludeSubscriptions = &includeSubscriptions + } users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters) if err != nil { @@ -173,13 +180,14 @@ func (h *UserHandler) Create(c *gin.Context) { } user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{ - Email: req.Email, - Password: req.Password, - Username: req.Username, - Notes: req.Notes, - Balance: req.Balance, - Concurrency: req.Concurrency, - AllowedGroups: req.AllowedGroups, + Email: req.Email, + Password: req.Password, + Username: req.Username, + Notes: req.Notes, + Balance: req.Balance, + Concurrency: req.Concurrency, + AllowedGroups: req.AllowedGroups, + SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, }) if err != nil { response.ErrorFrom(c, err) @@ -206,15 +214,16 @@ func (h *UserHandler) Update(c *gin.Context) { // 使用指针类型直接传递,nil 表示未提供该字段 user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{ - Email: req.Email, - Password: req.Password, - Username: req.Username, - Notes: req.Notes, - Balance: req.Balance, - Concurrency: req.Concurrency, - Status: req.Status, - AllowedGroups: req.AllowedGroups, - GroupRates: req.GroupRates, + Email: req.Email, + Password: req.Password, + Username: req.Username, + Notes: req.Notes, + Balance: req.Balance, + Concurrency: req.Concurrency, + Status: req.Status, + AllowedGroups: req.AllowedGroups, + GroupRates: req.GroupRates, + SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, }) if err != nil { response.ErrorFrom(c, err) @@ -257,13 +266,20 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) { return } - user, err := 
h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes) - if err != nil { - response.ErrorFrom(c, err) - return + idempotencyPayload := struct { + UserID int64 `json:"user_id"` + Body UpdateBalanceRequest `json:"body"` + }{ + UserID: userID, + Body: req, } - - response.Success(c, dto.UserFromServiceAdmin(user)) + executeAdminIdempotentJSON(c, "admin.users.balance.update", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + user, execErr := h.adminService.UpdateUserBalance(ctx, userID, req.Balance, req.Operation, req.Notes) + if execErr != nil { + return nil, execErr + } + return dto.UserFromServiceAdmin(user), nil + }) } // GetUserAPIKeys handles getting user's API keys diff --git a/backend/internal/handler/api_key_handler.go b/backend/internal/handler/api_key_handler.go index f1a18ad2..951aed08 100644 --- a/backend/internal/handler/api_key_handler.go +++ b/backend/internal/handler/api_key_handler.go @@ -2,7 +2,9 @@ package handler import ( + "context" "strconv" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/handler/dto" @@ -35,6 +37,11 @@ type CreateAPIKeyRequest struct { IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 Quota *float64 `json:"quota"` // 配额限制 (USD) ExpiresInDays *int `json:"expires_in_days"` // 过期天数 + + // Rate limit fields (0 = unlimited) + RateLimit5h *float64 `json:"rate_limit_5h"` + RateLimit1d *float64 `json:"rate_limit_1d"` + RateLimit7d *float64 `json:"rate_limit_7d"` } // UpdateAPIKeyRequest represents the update API key request payload @@ -47,6 +54,12 @@ type UpdateAPIKeyRequest struct { Quota *float64 `json:"quota"` // 配额限制 (USD), 0=无限制 ExpiresAt *string `json:"expires_at"` // 过期时间 (ISO 8601) ResetQuota *bool `json:"reset_quota"` // 重置已用配额 + + // Rate limit fields (nil = no change, 0 = unlimited) + RateLimit5h *float64 `json:"rate_limit_5h"` + RateLimit1d *float64 `json:"rate_limit_1d"` + RateLimit7d *float64 `json:"rate_limit_7d"` + 
ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // 重置限速用量 } // List handles listing user's API keys with pagination @@ -61,7 +74,23 @@ func (h *APIKeyHandler) List(c *gin.Context) { page, pageSize := response.ParsePagination(c) params := pagination.PaginationParams{Page: page, PageSize: pageSize} - keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params) + // Parse filter parameters + var filters service.APIKeyListFilters + if search := strings.TrimSpace(c.Query("search")); search != "" { + if len(search) > 100 { + search = search[:100] + } + filters.Search = search + } + filters.Status = c.Query("status") + if groupIDStr := c.Query("group_id"); groupIDStr != "" { + gid, err := strconv.ParseInt(groupIDStr, 10, 64) + if err == nil { + filters.GroupID = &gid + } + } + + keys, result, err := h.apiKeyService.List(c.Request.Context(), subject.UserID, params, filters) if err != nil { response.ErrorFrom(c, err) return @@ -130,13 +159,23 @@ func (h *APIKeyHandler) Create(c *gin.Context) { if req.Quota != nil { svcReq.Quota = *req.Quota } - key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq) - if err != nil { - response.ErrorFrom(c, err) - return + if req.RateLimit5h != nil { + svcReq.RateLimit5h = *req.RateLimit5h + } + if req.RateLimit1d != nil { + svcReq.RateLimit1d = *req.RateLimit1d + } + if req.RateLimit7d != nil { + svcReq.RateLimit7d = *req.RateLimit7d } - response.Success(c, dto.APIKeyFromService(key)) + executeUserIdempotentJSON(c, "user.api_keys.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + key, err := h.apiKeyService.Create(ctx, subject.UserID, svcReq) + if err != nil { + return nil, err + } + return dto.APIKeyFromService(key), nil + }) } // Update handles updating an API key @@ -161,10 +200,14 @@ func (h *APIKeyHandler) Update(c *gin.Context) { } svcReq := service.UpdateAPIKeyRequest{ - IPWhitelist: req.IPWhitelist, - IPBlacklist: 
req.IPBlacklist, - Quota: req.Quota, - ResetQuota: req.ResetQuota, + IPWhitelist: req.IPWhitelist, + IPBlacklist: req.IPBlacklist, + Quota: req.Quota, + ResetQuota: req.ResetQuota, + RateLimit5h: req.RateLimit5h, + RateLimit1d: req.RateLimit1d, + RateLimit7d: req.RateLimit7d, + ResetRateLimitUsage: req.ResetRateLimitUsage, } if req.Name != "" { svcReq.Name = &req.Name diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 34ed63bc..1ffa9d71 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -2,6 +2,7 @@ package handler import ( "log/slog" + "strings" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler/dto" @@ -112,12 +113,10 @@ func (h *AuthHandler) Register(c *gin.Context) { return } - // Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过) - if req.VerifyCode == "" { - if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil { - response.ErrorFrom(c, err) - return - } + // Turnstile 验证(邮箱验证码注册场景避免重复校验一次性 token) + if err := h.authService.VerifyTurnstileForRegister(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c), req.VerifyCode); err != nil { + response.ErrorFrom(c, err) + return } _, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode) @@ -448,17 +447,12 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) { return } - // Build frontend base URL from request - scheme := "https" - if c.Request.TLS == nil { - // Check X-Forwarded-Proto header (common in reverse proxy setups) - if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" { - scheme = proto - } else { - scheme = "http" - } + frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL) + if frontendBaseURL == "" { + slog.Error("server.frontend_url not configured; cannot build password reset link") + 
response.InternalError(c, "Password reset is not configured") + return } - frontendBaseURL := scheme + "://" + c.Request.Host // Request password reset (async) // Note: This returns success even if email doesn't exist (to prevent enumeration) diff --git a/backend/internal/handler/dto/api_key_mapper_last_used_test.go b/backend/internal/handler/dto/api_key_mapper_last_used_test.go new file mode 100644 index 00000000..99644ced --- /dev/null +++ b/backend/internal/handler/dto/api_key_mapper_last_used_test.go @@ -0,0 +1,40 @@ +package dto + +import ( + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestAPIKeyFromService_MapsLastUsedAt(t *testing.T) { + lastUsed := time.Now().UTC().Truncate(time.Second) + src := &service.APIKey{ + ID: 1, + UserID: 2, + Key: "sk-map-last-used", + Name: "Mapper", + Status: service.StatusActive, + LastUsedAt: &lastUsed, + } + + out := APIKeyFromService(src) + require.NotNil(t, out) + require.NotNil(t, out.LastUsedAt) + require.WithinDuration(t, lastUsed, *out.LastUsedAt, time.Second) +} + +func TestAPIKeyFromService_MapsNilLastUsedAt(t *testing.T) { + src := &service.APIKey{ + ID: 1, + UserID: 2, + Key: "sk-map-last-used-nil", + Name: "MapperNil", + Status: service.StatusActive, + } + + out := APIKeyFromService(src) + require.NotNil(t, out) + require.Nil(t, out.LastUsedAt) +} diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 2caf6847..fe2a1d77 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -2,6 +2,7 @@ package dto import ( + "strconv" "time" "github.com/Wei-Shaw/sub2api/internal/service" @@ -58,9 +59,11 @@ func UserFromServiceAdmin(u *service.User) *AdminUser { return nil } return &AdminUser{ - User: *base, - Notes: u.Notes, - GroupRates: u.GroupRates, + User: *base, + Notes: u.Notes, + GroupRates: u.GroupRates, + SoraStorageQuotaBytes: u.SoraStorageQuotaBytes, + 
SoraStorageUsedBytes: u.SoraStorageUsedBytes, } } @@ -69,21 +72,31 @@ func APIKeyFromService(k *service.APIKey) *APIKey { return nil } return &APIKey{ - ID: k.ID, - UserID: k.UserID, - Key: k.Key, - Name: k.Name, - GroupID: k.GroupID, - Status: k.Status, - IPWhitelist: k.IPWhitelist, - IPBlacklist: k.IPBlacklist, - Quota: k.Quota, - QuotaUsed: k.QuotaUsed, - ExpiresAt: k.ExpiresAt, - CreatedAt: k.CreatedAt, - UpdatedAt: k.UpdatedAt, - User: UserFromServiceShallow(k.User), - Group: GroupFromServiceShallow(k.Group), + ID: k.ID, + UserID: k.UserID, + Key: k.Key, + Name: k.Name, + GroupID: k.GroupID, + Status: k.Status, + IPWhitelist: k.IPWhitelist, + IPBlacklist: k.IPBlacklist, + LastUsedAt: k.LastUsedAt, + Quota: k.Quota, + QuotaUsed: k.QuotaUsed, + ExpiresAt: k.ExpiresAt, + CreatedAt: k.CreatedAt, + UpdatedAt: k.UpdatedAt, + RateLimit5h: k.RateLimit5h, + RateLimit1d: k.RateLimit1d, + RateLimit7d: k.RateLimit7d, + Usage5h: k.Usage5h, + Usage1d: k.Usage1d, + Usage7d: k.Usage7d, + Window5hStart: k.Window5hStart, + Window1dStart: k.Window1dStart, + Window7dStart: k.Window7dStart, + User: UserFromServiceShallow(k.User), + Group: GroupFromServiceShallow(k.Group), } } @@ -129,24 +142,28 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { func groupFromServiceBase(g *service.Group) Group { return Group{ - ID: g.ID, - Name: g.Name, - Description: g.Description, - Platform: g.Platform, - RateMultiplier: g.RateMultiplier, - IsExclusive: g.IsExclusive, - Status: g.Status, - SubscriptionType: g.SubscriptionType, - DailyLimitUSD: g.DailyLimitUSD, - WeeklyLimitUSD: g.WeeklyLimitUSD, - MonthlyLimitUSD: g.MonthlyLimitUSD, - ImagePrice1K: g.ImagePrice1K, - ImagePrice2K: g.ImagePrice2K, - ImagePrice4K: g.ImagePrice4K, - ClaudeCodeOnly: g.ClaudeCodeOnly, - FallbackGroupID: g.FallbackGroupID, - // 无效请求兜底分组 + ID: g.ID, + Name: g.Name, + Description: g.Description, + Platform: g.Platform, + RateMultiplier: g.RateMultiplier, + IsExclusive: g.IsExclusive, + Status: g.Status, + 
SubscriptionType: g.SubscriptionType, + DailyLimitUSD: g.DailyLimitUSD, + WeeklyLimitUSD: g.WeeklyLimitUSD, + MonthlyLimitUSD: g.MonthlyLimitUSD, + ImagePrice1K: g.ImagePrice1K, + ImagePrice2K: g.ImagePrice2K, + ImagePrice4K: g.ImagePrice4K, + SoraImagePrice360: g.SoraImagePrice360, + SoraImagePrice540: g.SoraImagePrice540, + SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD, + ClaudeCodeOnly: g.ClaudeCodeOnly, + FallbackGroupID: g.FallbackGroupID, FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest, + SoraStorageQuotaBytes: g.SoraStorageQuotaBytes, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } @@ -201,6 +218,17 @@ func AccountFromServiceShallow(a *service.Account) *Account { if idleTimeout := a.GetSessionIdleTimeoutMinutes(); idleTimeout > 0 { out.SessionIdleTimeoutMin = &idleTimeout } + if rpm := a.GetBaseRPM(); rpm > 0 { + out.BaseRPM = &rpm + strategy := a.GetRPMStrategy() + out.RPMStrategy = &strategy + buffer := a.GetRPMStickyBuffer() + out.RPMStickyBuffer = &buffer + } + // 用户消息队列模式 + if mode := a.GetUserMsgQueueMode(); mode != "" { + out.UserMsgQueueMode = &mode + } // TLS指纹伪装开关 if a.IsTLSFingerprintEnabled() { enabled := true @@ -211,6 +239,13 @@ func AccountFromServiceShallow(a *service.Account) *Account { enabled := true out.EnableSessionIDMasking = &enabled } + // 缓存 TTL 强制替换 + if a.IsCacheTTLOverrideEnabled() { + enabled := true + out.CacheTTLOverrideEnabled = &enabled + target := a.GetCacheTTLOverrideTarget() + out.CacheTTLOverrideTarget = &target + } } return out @@ -271,7 +306,6 @@ func ProxyFromService(p *service.Proxy) *Proxy { Host: p.Host, Port: p.Port, Username: p.Username, - Password: p.Password, Status: p.Status, CreatedAt: p.CreatedAt, UpdatedAt: p.UpdatedAt, @@ -293,6 +327,56 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi CountryCode: p.CountryCode, Region: p.Region, City: p.City, + QualityStatus: p.QualityStatus, + 
QualityScore: p.QualityScore, + QualityGrade: p.QualityGrade, + QualitySummary: p.QualitySummary, + QualityChecked: p.QualityChecked, + } +} + +// ProxyFromServiceAdmin converts a service Proxy to AdminProxy DTO for admin users. +// It includes the password field - user-facing endpoints must not use this. +func ProxyFromServiceAdmin(p *service.Proxy) *AdminProxy { + if p == nil { + return nil + } + base := ProxyFromService(p) + if base == nil { + return nil + } + return &AdminProxy{ + Proxy: *base, + Password: p.Password, + } +} + +// ProxyWithAccountCountFromServiceAdmin converts a service ProxyWithAccountCount to AdminProxyWithAccountCount DTO. +// It includes the password field - user-facing endpoints must not use this. +func ProxyWithAccountCountFromServiceAdmin(p *service.ProxyWithAccountCount) *AdminProxyWithAccountCount { + if p == nil { + return nil + } + admin := ProxyFromServiceAdmin(&p.Proxy) + if admin == nil { + return nil + } + return &AdminProxyWithAccountCount{ + AdminProxy: *admin, + AccountCount: p.AccountCount, + LatencyMs: p.LatencyMs, + LatencyStatus: p.LatencyStatus, + LatencyMessage: p.LatencyMessage, + IPAddress: p.IPAddress, + Country: p.Country, + CountryCode: p.CountryCode, + Region: p.Region, + City: p.City, + QualityStatus: p.QualityStatus, + QualityScore: p.QualityScore, + QualityGrade: p.QualityGrade, + QualitySummary: p.QualitySummary, + QualityChecked: p.QualityChecked, } } @@ -368,6 +452,8 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary { func usageLogFromServiceUser(l *service.UsageLog) UsageLog { // 普通用户 DTO:严禁包含管理员字段(例如 account_rate_multiplier、ip_address、account)。 + requestType := l.EffectiveRequestType() + stream, openAIWSMode := service.ApplyLegacyRequestFields(requestType, l.Stream, l.OpenAIWSMode) return UsageLog{ ID: l.ID, UserID: l.UserID, @@ -392,12 +478,16 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { ActualCost: l.ActualCost, RateMultiplier: l.RateMultiplier, BillingType: 
l.BillingType, - Stream: l.Stream, + RequestType: requestType.String(), + Stream: stream, + OpenAIWSMode: openAIWSMode, DurationMs: l.DurationMs, FirstTokenMs: l.FirstTokenMs, ImageCount: l.ImageCount, ImageSize: l.ImageSize, + MediaType: l.MediaType, UserAgent: l.UserAgent, + CacheTTLOverridden: l.CacheTTLOverridden, CreatedAt: l.CreatedAt, User: UserFromServiceShallow(l.User), APIKey: APIKeyFromService(l.APIKey), @@ -445,6 +535,7 @@ func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTa AccountID: task.Filters.AccountID, GroupID: task.Filters.GroupID, Model: task.Filters.Model, + RequestType: requestTypeStringPtr(task.Filters.RequestType), Stream: task.Filters.Stream, BillingType: task.Filters.BillingType, }, @@ -460,6 +551,14 @@ func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTa } } +func requestTypeStringPtr(requestType *int16) *string { + if requestType == nil { + return nil + } + value := service.RequestTypeFromInt16(*requestType).String() + return &value +} + func SettingFromService(s *service.Setting) *Setting { if s == nil { return nil @@ -524,11 +623,18 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult for i := range r.Subscriptions { subs = append(subs, *UserSubscriptionFromServiceAdmin(&r.Subscriptions[i])) } + statuses := make(map[string]string, len(r.Statuses)) + for userID, status := range r.Statuses { + statuses[strconv.FormatInt(userID, 10)] = status + } return &BulkAssignResult{ SuccessCount: r.SuccessCount, + CreatedCount: r.CreatedCount, + ReusedCount: r.ReusedCount, FailedCount: r.FailedCount, Subscriptions: subs, Errors: r.Errors, + Statuses: statuses, } } diff --git a/backend/internal/handler/dto/mappers_usage_test.go b/backend/internal/handler/dto/mappers_usage_test.go new file mode 100644 index 00000000..d716bdc4 --- /dev/null +++ b/backend/internal/handler/dto/mappers_usage_test.go @@ -0,0 +1,73 @@ +package dto + +import ( + "testing" + + 
"github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestUsageLogFromService_IncludesOpenAIWSMode(t *testing.T) { + t.Parallel() + + wsLog := &service.UsageLog{ + RequestID: "req_1", + Model: "gpt-5.3-codex", + OpenAIWSMode: true, + } + httpLog := &service.UsageLog{ + RequestID: "resp_1", + Model: "gpt-5.3-codex", + OpenAIWSMode: false, + } + + require.True(t, UsageLogFromService(wsLog).OpenAIWSMode) + require.False(t, UsageLogFromService(httpLog).OpenAIWSMode) + require.True(t, UsageLogFromServiceAdmin(wsLog).OpenAIWSMode) + require.False(t, UsageLogFromServiceAdmin(httpLog).OpenAIWSMode) +} + +func TestUsageLogFromService_PrefersRequestTypeForLegacyFields(t *testing.T) { + t.Parallel() + + log := &service.UsageLog{ + RequestID: "req_2", + Model: "gpt-5.3-codex", + RequestType: service.RequestTypeWSV2, + Stream: false, + OpenAIWSMode: false, + } + + userDTO := UsageLogFromService(log) + adminDTO := UsageLogFromServiceAdmin(log) + + require.Equal(t, "ws_v2", userDTO.RequestType) + require.True(t, userDTO.Stream) + require.True(t, userDTO.OpenAIWSMode) + require.Equal(t, "ws_v2", adminDTO.RequestType) + require.True(t, adminDTO.Stream) + require.True(t, adminDTO.OpenAIWSMode) +} + +func TestUsageCleanupTaskFromService_RequestTypeMapping(t *testing.T) { + t.Parallel() + + requestType := int16(service.RequestTypeStream) + task := &service.UsageCleanupTask{ + ID: 1, + Status: service.UsageCleanupStatusPending, + Filters: service.UsageCleanupFilters{ + RequestType: &requestType, + }, + } + + dtoTask := UsageCleanupTaskFromService(task) + require.NotNil(t, dtoTask) + require.NotNil(t, dtoTask.Filters.RequestType) + require.Equal(t, "stream", *dtoTask.Filters.RequestType) +} + +func TestRequestTypeStringPtrNil(t *testing.T) { + t.Parallel() + require.Nil(t, requestTypeStringPtr(nil)) +} diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index be94bc16..c34c6de1 100644 --- 
a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -1,14 +1,30 @@ package dto +import ( + "encoding/json" + "strings" +) + +// CustomMenuItem represents a user-configured custom menu entry. +type CustomMenuItem struct { + ID string `json:"id"` + Label string `json:"label"` + IconSVG string `json:"icon_svg"` + URL string `json:"url"` + Visibility string `json:"visibility"` // "user" or "admin" + SortOrder int `json:"sort_order"` +} + // SystemSettings represents the admin settings API response payload. type SystemSettings struct { - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - PasswordResetEnabled bool `json:"password_reset_enabled"` - InvitationCodeEnabled bool `json:"invitation_code_enabled"` - TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 - TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置 + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` + TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 + TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置 SMTPHost string `json:"smtp_host"` SMTPPort int `json:"smtp_port"` @@ -27,19 +43,22 @@ type SystemSettings struct { LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"` LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo"` - SiteSubtitle string `json:"site_subtitle"` - APIBaseURL string `json:"api_base_url"` - ContactInfo string 
`json:"contact_info"` - DocURL string `json:"doc_url"` - HomeContent string `json:"home_content"` - HideCcsImportButton bool `json:"hide_ccs_import_button"` - PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` - PurchaseSubscriptionURL string `json:"purchase_subscription_url"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo"` + SiteSubtitle string `json:"site_subtitle"` + APIBaseURL string `json:"api_base_url"` + ContactInfo string `json:"contact_info"` + DocURL string `json:"doc_url"` + HomeContent string `json:"home_content"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` + PurchaseSubscriptionURL string `json:"purchase_subscription_url"` + SoraClientEnabled bool `json:"sora_client_enabled"` + CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` - DefaultConcurrency int `json:"default_concurrency"` - DefaultBalance float64 `json:"default_balance"` + DefaultConcurrency int `json:"default_concurrency"` + DefaultBalance float64 `json:"default_balance"` + DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"` // Model fallback configuration EnableModelFallback bool `json:"enable_model_fallback"` @@ -57,29 +76,80 @@ type SystemSettings struct { OpsRealtimeMonitoringEnabled bool `json:"ops_realtime_monitoring_enabled"` OpsQueryModeDefault string `json:"ops_query_mode_default"` OpsMetricsIntervalSeconds int `json:"ops_metrics_interval_seconds"` + + MinClaudeCodeVersion string `json:"min_claude_code_version"` + + // 分组隔离 + AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"` +} + +type DefaultSubscriptionSetting struct { + GroupID int64 `json:"group_id"` + ValidityDays int `json:"validity_days"` } type PublicSettings struct { - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - 
PasswordResetEnabled bool `json:"password_reset_enabled"` - InvitationCodeEnabled bool `json:"invitation_code_enabled"` - TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 - TurnstileEnabled bool `json:"turnstile_enabled"` - TurnstileSiteKey string `json:"turnstile_site_key"` - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo"` - SiteSubtitle string `json:"site_subtitle"` - APIBaseURL string `json:"api_base_url"` - ContactInfo string `json:"contact_info"` - DocURL string `json:"doc_url"` - HomeContent string `json:"home_content"` - HideCcsImportButton bool `json:"hide_ccs_import_button"` - PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` - PurchaseSubscriptionURL string `json:"purchase_subscription_url"` - LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` - Version string `json:"version"` + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` + TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 + TurnstileEnabled bool `json:"turnstile_enabled"` + TurnstileSiteKey string `json:"turnstile_site_key"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo"` + SiteSubtitle string `json:"site_subtitle"` + APIBaseURL string `json:"api_base_url"` + ContactInfo string `json:"contact_info"` + DocURL string `json:"doc_url"` + HomeContent string `json:"home_content"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` + PurchaseSubscriptionURL string `json:"purchase_subscription_url"` + CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` + LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + 
SoraClientEnabled bool `json:"sora_client_enabled"` + Version string `json:"version"` +} + +// SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段) +type SoraS3Settings struct { + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` +} + +// SoraS3Profile Sora S3 存储配置项 DTO(响应用,不含敏感字段) +type SoraS3Profile struct { + ProfileID string `json:"profile_id"` + Name string `json:"name"` + IsActive bool `json:"is_active"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` + UpdatedAt string `json:"updated_at"` +} + +// ListSoraS3ProfilesResponse Sora S3 配置列表响应 +type ListSoraS3ProfilesResponse struct { + ActiveProfileID string `json:"active_profile_id"` + Items []SoraS3Profile `json:"items"` } // StreamTimeoutSettings 流超时处理配置 DTO @@ -90,3 +160,29 @@ type StreamTimeoutSettings struct { ThresholdCount int `json:"threshold_count"` ThresholdWindowMinutes int `json:"threshold_window_minutes"` } + +// ParseCustomMenuItems parses a JSON string into a slice of CustomMenuItem. +// Returns empty slice on empty/invalid input. 
+func ParseCustomMenuItems(raw string) []CustomMenuItem { + raw = strings.TrimSpace(raw) + if raw == "" || raw == "[]" { + return []CustomMenuItem{} + } + var items []CustomMenuItem + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return []CustomMenuItem{} + } + return items +} + +// ParseUserVisibleMenuItems parses custom menu items and filters out admin-only entries. +func ParseUserVisibleMenuItems(raw string) []CustomMenuItem { + items := ParseCustomMenuItems(raw) + filtered := make([]CustomMenuItem, 0, len(items)) + for _, item := range items { + if item.Visibility != "admin" { + filtered = append(filtered, item) + } + } + return filtered +} diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 2c1ae83c..920615f7 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -26,7 +26,9 @@ type AdminUser struct { Notes string `json:"notes"` // GroupRates 用户专属分组倍率配置 // map[groupID]rateMultiplier - GroupRates map[int64]float64 `json:"group_rates,omitempty"` + GroupRates map[int64]float64 `json:"group_rates,omitempty"` + SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` + SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes"` } type APIKey struct { @@ -38,12 +40,24 @@ type APIKey struct { Status string `json:"status"` IPWhitelist []string `json:"ip_whitelist"` IPBlacklist []string `json:"ip_blacklist"` + LastUsedAt *time.Time `json:"last_used_at"` Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited) QuotaUsed float64 `json:"quota_used"` // Used quota amount in USD ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = never expires) CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` + // Rate limit fields + RateLimit5h float64 `json:"rate_limit_5h"` + RateLimit1d float64 `json:"rate_limit_1d"` + RateLimit7d float64 `json:"rate_limit_7d"` + Usage5h float64 `json:"usage_5h"` + Usage1d float64 
`json:"usage_1d"` + Usage7d float64 `json:"usage_7d"` + Window5hStart *time.Time `json:"window_5h_start"` + Window1dStart *time.Time `json:"window_1d_start"` + Window7dStart *time.Time `json:"window_7d_start"` + User *User `json:"user,omitempty"` Group *Group `json:"group,omitempty"` } @@ -67,12 +81,21 @@ type Group struct { ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` + // Sora 按次计费配置 + SoraImagePrice360 *float64 `json:"sora_image_price_360"` + SoraImagePrice540 *float64 `json:"sora_image_price_540"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` + SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` + // Claude Code 客户端限制 ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` // 无效请求兜底分组 FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` + // Sora 存储配额 + SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` + CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } @@ -141,6 +164,13 @@ type Account struct { MaxSessions *int `json:"max_sessions,omitempty"` SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"` + // RPM 限制(仅 Anthropic OAuth/SetupToken 账号有效) + // 从 extra 字段提取,方便前端显示和编辑 + BaseRPM *int `json:"base_rpm,omitempty"` + RPMStrategy *string `json:"rpm_strategy,omitempty"` + RPMStickyBuffer *int `json:"rpm_sticky_buffer,omitempty"` + UserMsgQueueMode *string `json:"user_msg_queue_mode,omitempty"` + // TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效) // 从 extra 字段提取,方便前端显示和编辑 EnableTLSFingerprint *bool `json:"enable_tls_fingerprint,omitempty"` @@ -150,6 +180,11 @@ type Account struct { // 从 extra 字段提取,方便前端显示和编辑 EnableSessionIDMasking *bool `json:"session_id_masking_enabled,omitempty"` + // 缓存 TTL 强制替换(仅 Anthropic OAuth/SetupToken 账号有效) + // 启用后将所有 cache creation tokens 归入指定的 TTL 类型计费 + CacheTTLOverrideEnabled *bool 
`json:"cache_ttl_override_enabled,omitempty"` + CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"` + Proxy *Proxy `json:"proxy,omitempty"` AccountGroups []AccountGroup `json:"account_groups,omitempty"` @@ -191,6 +226,37 @@ type ProxyWithAccountCount struct { CountryCode string `json:"country_code,omitempty"` Region string `json:"region,omitempty"` City string `json:"city,omitempty"` + QualityStatus string `json:"quality_status,omitempty"` + QualityScore *int `json:"quality_score,omitempty"` + QualityGrade string `json:"quality_grade,omitempty"` + QualitySummary string `json:"quality_summary,omitempty"` + QualityChecked *int64 `json:"quality_checked,omitempty"` +} + +// AdminProxy 是管理员接口使用的 proxy DTO(包含密码等敏感字段)。 +// 注意:普通接口不得使用此 DTO。 +type AdminProxy struct { + Proxy + Password string `json:"password,omitempty"` +} + +// AdminProxyWithAccountCount 是管理员接口使用的带账号统计的 proxy DTO。 +type AdminProxyWithAccountCount struct { + AdminProxy + AccountCount int64 `json:"account_count"` + LatencyMs *int64 `json:"latency_ms,omitempty"` + LatencyStatus string `json:"latency_status,omitempty"` + LatencyMessage string `json:"latency_message,omitempty"` + IPAddress string `json:"ip_address,omitempty"` + Country string `json:"country,omitempty"` + CountryCode string `json:"country_code,omitempty"` + Region string `json:"region,omitempty"` + City string `json:"city,omitempty"` + QualityStatus string `json:"quality_status,omitempty"` + QualityScore *int `json:"quality_score,omitempty"` + QualityGrade string `json:"quality_grade,omitempty"` + QualitySummary string `json:"quality_summary,omitempty"` + QualityChecked *int64 `json:"quality_checked,omitempty"` } type ProxyAccountSummary struct { @@ -261,18 +327,24 @@ type UsageLog struct { ActualCost float64 `json:"actual_cost"` RateMultiplier float64 `json:"rate_multiplier"` - BillingType int8 `json:"billing_type"` - Stream bool `json:"stream"` - DurationMs *int `json:"duration_ms"` - FirstTokenMs *int 
`json:"first_token_ms"` + BillingType int8 `json:"billing_type"` + RequestType string `json:"request_type"` + Stream bool `json:"stream"` + OpenAIWSMode bool `json:"openai_ws_mode"` + DurationMs *int `json:"duration_ms"` + FirstTokenMs *int `json:"first_token_ms"` // 图片生成字段 ImageCount int `json:"image_count"` ImageSize *string `json:"image_size"` + MediaType *string `json:"media_type"` // User-Agent UserAgent *string `json:"user_agent"` + // Cache TTL Override 标记 + CacheTTLOverridden bool `json:"cache_ttl_overridden"` + CreatedAt time.Time `json:"created_at"` User *User `json:"user,omitempty"` @@ -303,6 +375,7 @@ type UsageCleanupFilters struct { AccountID *int64 `json:"account_id,omitempty"` GroupID *int64 `json:"group_id,omitempty"` Model *string `json:"model,omitempty"` + RequestType *string `json:"request_type,omitempty"` Stream *bool `json:"stream,omitempty"` BillingType *int8 `json:"billing_type,omitempty"` } @@ -374,9 +447,12 @@ type AdminUserSubscription struct { type BulkAssignResult struct { SuccessCount int `json:"success_count"` + CreatedCount int `json:"created_count"` + ReusedCount int `json:"reused_count"` FailedCount int `json:"failed_count"` Subscriptions []AdminUserSubscription `json:"subscriptions"` Errors []string `json:"errors"` + Statuses map[string]string `json:"statuses,omitempty"` } // PromoCode 注册优惠码 diff --git a/backend/internal/handler/failover_loop.go b/backend/internal/handler/failover_loop.go new file mode 100644 index 00000000..b2583301 --- /dev/null +++ b/backend/internal/handler/failover_loop.go @@ -0,0 +1,174 @@ +package handler + +import ( + "context" + "net/http" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/service" + "go.uber.org/zap" +) + +// TempUnscheduler 用于 HandleFailoverError 中同账号重试耗尽后的临时封禁。 +// GatewayService 隐式实现此接口。 +type TempUnscheduler interface { + TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *service.UpstreamFailoverError) 
+} + +// FailoverAction 表示 failover 错误处理后的下一步动作 +type FailoverAction int + +const ( + // FailoverContinue 继续循环(同账号重试或切换账号,调用方统一 continue) + FailoverContinue FailoverAction = iota + // FailoverExhausted 切换次数耗尽(调用方应返回错误响应) + FailoverExhausted + // FailoverCanceled context 已取消(调用方应直接 return) + FailoverCanceled +) + +const ( + // maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误) + maxSameAccountRetries = 2 + // sameAccountRetryDelay 同账号重试间隔 + sameAccountRetryDelay = 500 * time.Millisecond + // singleAccountBackoffDelay 单账号分组 503 退避重试固定延时。 + // Service 层在 SingleAccountRetry 模式下已做充分原地重试(最多 3 次、总等待 30s), + // Handler 层只需短暂间隔后重新进入 Service 层即可。 + singleAccountBackoffDelay = 2 * time.Second +) + +// FailoverState 跨循环迭代共享的 failover 状态 +type FailoverState struct { + SwitchCount int + MaxSwitches int + FailedAccountIDs map[int64]struct{} + SameAccountRetryCount map[int64]int + LastFailoverErr *service.UpstreamFailoverError + ForceCacheBilling bool + hasBoundSession bool +} + +// NewFailoverState 创建 failover 状态 +func NewFailoverState(maxSwitches int, hasBoundSession bool) *FailoverState { + return &FailoverState{ + MaxSwitches: maxSwitches, + FailedAccountIDs: make(map[int64]struct{}), + SameAccountRetryCount: make(map[int64]int), + hasBoundSession: hasBoundSession, + } +} + +// HandleFailoverError 处理 UpstreamFailoverError,返回下一步动作。 +// 包含:缓存计费判断、同账号重试、临时封禁、切换计数、Antigravity 延时。 +func (s *FailoverState) HandleFailoverError( + ctx context.Context, + gatewayService TempUnscheduler, + accountID int64, + platform string, + failoverErr *service.UpstreamFailoverError, +) FailoverAction { + s.LastFailoverErr = failoverErr + + // 缓存计费判断 + if needForceCacheBilling(s.hasBoundSession, failoverErr) { + s.ForceCacheBilling = true + } + + // 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试 + if failoverErr.RetryableOnSameAccount && s.SameAccountRetryCount[accountID] < maxSameAccountRetries { + s.SameAccountRetryCount[accountID]++ + 
logger.FromContext(ctx).Warn("gateway.failover_same_account_retry", + zap.Int64("account_id", accountID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("same_account_retry_count", s.SameAccountRetryCount[accountID]), + zap.Int("same_account_retry_max", maxSameAccountRetries), + ) + if !sleepWithContext(ctx, sameAccountRetryDelay) { + return FailoverCanceled + } + return FailoverContinue + } + + // 同账号重试用尽,执行临时封禁 + if failoverErr.RetryableOnSameAccount { + gatewayService.TempUnscheduleRetryableError(ctx, accountID, failoverErr) + } + + // 加入失败列表 + s.FailedAccountIDs[accountID] = struct{}{} + + // 检查是否耗尽 + if s.SwitchCount >= s.MaxSwitches { + return FailoverExhausted + } + + // 递增切换计数 + s.SwitchCount++ + logger.FromContext(ctx).Warn("gateway.failover_switch_account", + zap.Int64("account_id", accountID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", s.SwitchCount), + zap.Int("max_switches", s.MaxSwitches), + ) + + // Antigravity 平台换号线性递增延时 + if platform == service.PlatformAntigravity { + delay := time.Duration(s.SwitchCount-1) * time.Second + if !sleepWithContext(ctx, delay) { + return FailoverCanceled + } + } + + return FailoverContinue +} + +// HandleSelectionExhausted 处理选号失败(所有候选账号都在排除列表中)时的退避重试决策。 +// 针对 Antigravity 单账号分组的 503 (MODEL_CAPACITY_EXHAUSTED) 场景: +// 清除排除列表、等待退避后重新选号。 +// +// 返回 FailoverContinue 时,调用方应设置 SingleAccountRetry context 并 continue。 +// 返回 FailoverExhausted 时,调用方应返回错误响应。 +// 返回 FailoverCanceled 时,调用方应直接 return。 +func (s *FailoverState) HandleSelectionExhausted(ctx context.Context) FailoverAction { + if s.LastFailoverErr != nil && + s.LastFailoverErr.StatusCode == http.StatusServiceUnavailable && + s.SwitchCount <= s.MaxSwitches { + + logger.FromContext(ctx).Warn("gateway.failover_single_account_backoff", + zap.Duration("backoff_delay", singleAccountBackoffDelay), + zap.Int("switch_count", s.SwitchCount), + zap.Int("max_switches", s.MaxSwitches), + ) + if !sleepWithContext(ctx, 
singleAccountBackoffDelay) { + return FailoverCanceled + } + logger.FromContext(ctx).Warn("gateway.failover_single_account_retry", + zap.Int("switch_count", s.SwitchCount), + zap.Int("max_switches", s.MaxSwitches), + ) + s.FailedAccountIDs = make(map[int64]struct{}) + return FailoverContinue + } + return FailoverExhausted +} + +// needForceCacheBilling 判断 failover 时是否需要强制缓存计费。 +// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费。 +func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool { + return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling) +} + +// sleepWithContext 等待指定时长,返回 false 表示 context 已取消。 +func sleepWithContext(ctx context.Context, d time.Duration) bool { + if d <= 0 { + return true + } + select { + case <-ctx.Done(): + return false + case <-time.After(d): + return true + } +} diff --git a/backend/internal/handler/failover_loop_test.go b/backend/internal/handler/failover_loop_test.go new file mode 100644 index 00000000..5a41b2dd --- /dev/null +++ b/backend/internal/handler/failover_loop_test.go @@ -0,0 +1,732 @@ +package handler + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// Mock +// --------------------------------------------------------------------------- + +// mockTempUnscheduler 记录 TempUnscheduleRetryableError 的调用信息。 +type mockTempUnscheduler struct { + calls []tempUnscheduleCall +} + +type tempUnscheduleCall struct { + accountID int64 + failoverErr *service.UpstreamFailoverError +} + +func (m *mockTempUnscheduler) TempUnscheduleRetryableError(_ context.Context, accountID int64, failoverErr *service.UpstreamFailoverError) { + m.calls = append(m.calls, tempUnscheduleCall{accountID: accountID, failoverErr: failoverErr}) +} + +// --------------------------------------------------------------------------- +// 
Helper +// --------------------------------------------------------------------------- + +func newTestFailoverErr(statusCode int, retryable, forceBilling bool) *service.UpstreamFailoverError { + return &service.UpstreamFailoverError{ + StatusCode: statusCode, + RetryableOnSameAccount: retryable, + ForceCacheBilling: forceBilling, + } +} + +// --------------------------------------------------------------------------- +// NewFailoverState 测试 +// --------------------------------------------------------------------------- + +func TestNewFailoverState(t *testing.T) { + t.Run("初始化字段正确", func(t *testing.T) { + fs := NewFailoverState(5, true) + require.Equal(t, 5, fs.MaxSwitches) + require.Equal(t, 0, fs.SwitchCount) + require.NotNil(t, fs.FailedAccountIDs) + require.Empty(t, fs.FailedAccountIDs) + require.NotNil(t, fs.SameAccountRetryCount) + require.Empty(t, fs.SameAccountRetryCount) + require.Nil(t, fs.LastFailoverErr) + require.False(t, fs.ForceCacheBilling) + require.True(t, fs.hasBoundSession) + }) + + t.Run("无绑定会话", func(t *testing.T) { + fs := NewFailoverState(3, false) + require.Equal(t, 3, fs.MaxSwitches) + require.False(t, fs.hasBoundSession) + }) + + t.Run("零最大切换次数", func(t *testing.T) { + fs := NewFailoverState(0, false) + require.Equal(t, 0, fs.MaxSwitches) + }) +} + +// --------------------------------------------------------------------------- +// sleepWithContext 测试 +// --------------------------------------------------------------------------- + +func TestSleepWithContext(t *testing.T) { + t.Run("零时长立即返回true", func(t *testing.T) { + start := time.Now() + ok := sleepWithContext(context.Background(), 0) + require.True(t, ok) + require.Less(t, time.Since(start), 50*time.Millisecond) + }) + + t.Run("负时长立即返回true", func(t *testing.T) { + start := time.Now() + ok := sleepWithContext(context.Background(), -1*time.Second) + require.True(t, ok) + require.Less(t, time.Since(start), 50*time.Millisecond) + }) + + t.Run("正常等待后返回true", func(t *testing.T) { + start := 
time.Now() + ok := sleepWithContext(context.Background(), 50*time.Millisecond) + elapsed := time.Since(start) + require.True(t, ok) + require.GreaterOrEqual(t, elapsed, 40*time.Millisecond) + require.Less(t, elapsed, 500*time.Millisecond) + }) + + t.Run("已取消context立即返回false", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + start := time.Now() + ok := sleepWithContext(ctx, 5*time.Second) + require.False(t, ok) + require.Less(t, time.Since(start), 50*time.Millisecond) + }) + + t.Run("等待期间context取消返回false", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(30 * time.Millisecond) + cancel() + }() + + start := time.Now() + ok := sleepWithContext(ctx, 5*time.Second) + elapsed := time.Since(start) + require.False(t, ok) + require.Less(t, elapsed, 500*time.Millisecond) + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — 基本切换流程 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_BasicSwitch(t *testing.T) { + t.Run("非重试错误_非Antigravity_直接切换", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, false, false) + + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SwitchCount) + require.Contains(t, fs.FailedAccountIDs, int64(100)) + require.Equal(t, err, fs.LastFailoverErr) + require.False(t, fs.ForceCacheBilling) + require.Empty(t, mock.calls, "不应调用 TempUnschedule") + }) + + t.Run("非重试错误_Antigravity_第一次切换无延迟", func(t *testing.T) { + // switchCount 从 0→1 时,sleepFailoverDelay(ctx, 1) 的延时 = (1-1)*1s = 0 + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, false, false) + + start := time.Now() + action := 
fs.HandleFailoverError(context.Background(), mock, 100, service.PlatformAntigravity, err) + elapsed := time.Since(start) + + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SwitchCount) + require.Less(t, elapsed, 200*time.Millisecond, "第一次切换延迟应为 0") + }) + + t.Run("非重试错误_Antigravity_第二次切换有1秒延迟", func(t *testing.T) { + // switchCount 从 1→2 时,sleepFailoverDelay(ctx, 2) 的延时 = (2-1)*1s = 1s + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + fs.SwitchCount = 1 // 模拟已切换一次 + + err := newTestFailoverErr(500, false, false) + start := time.Now() + action := fs.HandleFailoverError(context.Background(), mock, 200, service.PlatformAntigravity, err) + elapsed := time.Since(start) + + require.Equal(t, FailoverContinue, action) + require.Equal(t, 2, fs.SwitchCount) + require.GreaterOrEqual(t, elapsed, 800*time.Millisecond, "第二次切换延迟应约 1s") + require.Less(t, elapsed, 3*time.Second) + }) + + t.Run("连续切换直到耗尽", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(2, false) + + // 第一次切换:0→1 + err1 := newTestFailoverErr(500, false, false) + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SwitchCount) + + // 第二次切换:1→2 + err2 := newTestFailoverErr(502, false, false) + action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 2, fs.SwitchCount) + + // 第三次已耗尽:SwitchCount(2) >= MaxSwitches(2) + err3 := newTestFailoverErr(503, false, false) + action = fs.HandleFailoverError(context.Background(), mock, 300, "openai", err3) + require.Equal(t, FailoverExhausted, action) + require.Equal(t, 2, fs.SwitchCount, "耗尽时不应继续递增") + + // 验证失败账号列表 + require.Len(t, fs.FailedAccountIDs, 3) + require.Contains(t, fs.FailedAccountIDs, int64(100)) + require.Contains(t, fs.FailedAccountIDs, int64(200)) + require.Contains(t, fs.FailedAccountIDs, int64(300)) + + // 
LastFailoverErr 应为最后一次的错误 + require.Equal(t, err3, fs.LastFailoverErr) + }) + + t.Run("MaxSwitches为0时首次即耗尽", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(0, false) + err := newTestFailoverErr(500, false, false) + + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverExhausted, action) + require.Equal(t, 0, fs.SwitchCount) + require.Contains(t, fs.FailedAccountIDs, int64(100)) + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — 缓存计费 (ForceCacheBilling) +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_CacheBilling(t *testing.T) { + t.Run("hasBoundSession为true时设置ForceCacheBilling", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, true) // hasBoundSession=true + err := newTestFailoverErr(500, false, false) + + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.True(t, fs.ForceCacheBilling) + }) + + t.Run("failoverErr.ForceCacheBilling为true时设置", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, false, true) // ForceCacheBilling=true + + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.True(t, fs.ForceCacheBilling) + }) + + t.Run("两者均为false时不设置", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, false, false) + + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.False(t, fs.ForceCacheBilling) + }) + + t.Run("一旦设置不会被后续错误重置", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + + // 第一次:ForceCacheBilling=true → 设置 + err1 := newTestFailoverErr(500, false, true) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1) + require.True(t, 
fs.ForceCacheBilling) + + // 第二次:ForceCacheBilling=false → 仍然保持 true + err2 := newTestFailoverErr(502, false, false) + fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2) + require.True(t, fs.ForceCacheBilling, "ForceCacheBilling 一旦设置不应被重置") + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — 同账号重试 (RetryableOnSameAccount) +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_SameAccountRetry(t *testing.T) { + t.Run("第一次重试返回FailoverContinue", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(400, true, false) + + start := time.Now() + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + elapsed := time.Since(start) + + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SameAccountRetryCount[100]) + require.Equal(t, 0, fs.SwitchCount, "同账号重试不应增加切换计数") + require.NotContains(t, fs.FailedAccountIDs, int64(100), "同账号重试不应加入失败列表") + require.Empty(t, mock.calls, "同账号重试期间不应调用 TempUnschedule") + // 验证等待了 sameAccountRetryDelay (500ms) + require.GreaterOrEqual(t, elapsed, 400*time.Millisecond) + require.Less(t, elapsed, 2*time.Second) + }) + + t.Run("第二次重试仍返回FailoverContinue", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(400, true, false) + + // 第一次 + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SameAccountRetryCount[100]) + + // 第二次 + action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 2, fs.SameAccountRetryCount[100]) + + require.Empty(t, mock.calls, "两次重试期间均不应调用 TempUnschedule") + }) + + t.Run("第三次重试耗尽_触发TempUnschedule并切换", func(t *testing.T) { + mock := 
&mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(400, true, false) + + // 第一次、第二次重试 + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, 2, fs.SameAccountRetryCount[100]) + + // 第三次:重试已达到 maxSameAccountRetries(2),应切换账号 + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SwitchCount) + require.Contains(t, fs.FailedAccountIDs, int64(100)) + + // 验证 TempUnschedule 被调用 + require.Len(t, mock.calls, 1) + require.Equal(t, int64(100), mock.calls[0].accountID) + require.Equal(t, err, mock.calls[0].failoverErr) + }) + + t.Run("不同账号独立跟踪重试次数", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(5, false) + err := newTestFailoverErr(400, true, false) + + // 账号 100 第一次重试 + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SameAccountRetryCount[100]) + + // 账号 200 第一次重试(独立计数) + action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SameAccountRetryCount[200]) + require.Equal(t, 1, fs.SameAccountRetryCount[100], "账号 100 的计数不应受影响") + }) + + t.Run("重试耗尽后再次遇到同账号_直接切换", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(5, false) + err := newTestFailoverErr(400, true, false) + + // 耗尽账号 100 的重试 + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + // 第三次: 重试耗尽 → 切换 + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + + // 再次遇到账号 100,计数仍为 2,条件不满足 → 直接切换 + action = fs.HandleFailoverError(context.Background(), mock, 100, 
"openai", err) + require.Equal(t, FailoverContinue, action) + require.Len(t, mock.calls, 2, "第二次耗尽也应调用 TempUnschedule") + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — TempUnschedule 调用验证 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_TempUnschedule(t *testing.T) { + t.Run("非重试错误不调用TempUnschedule", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, false, false) // RetryableOnSameAccount=false + + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Empty(t, mock.calls) + }) + + t.Run("重试错误耗尽后调用TempUnschedule_传入正确参数", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(502, true, false) + + // 耗尽重试 + fs.HandleFailoverError(context.Background(), mock, 42, "openai", err) + fs.HandleFailoverError(context.Background(), mock, 42, "openai", err) + fs.HandleFailoverError(context.Background(), mock, 42, "openai", err) + + require.Len(t, mock.calls, 1) + require.Equal(t, int64(42), mock.calls[0].accountID) + require.Equal(t, 502, mock.calls[0].failoverErr.StatusCode) + require.True(t, mock.calls[0].failoverErr.RetryableOnSameAccount) + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — Context 取消 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_ContextCanceled(t *testing.T) { + t.Run("同账号重试sleep期间context取消", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(400, true, false) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // 立即取消 + + start := time.Now() + action := fs.HandleFailoverError(ctx, mock, 100, "openai", err) + elapsed := time.Since(start) + + require.Equal(t, 
FailoverCanceled, action) + require.Less(t, elapsed, 100*time.Millisecond, "应立即返回") + // 重试计数仍应递增 + require.Equal(t, 1, fs.SameAccountRetryCount[100]) + }) + + t.Run("Antigravity延迟期间context取消", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + fs.SwitchCount = 1 // 下一次 switchCount=2 → delay = 1s + err := newTestFailoverErr(500, false, false) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // 立即取消 + + start := time.Now() + action := fs.HandleFailoverError(ctx, mock, 100, service.PlatformAntigravity, err) + elapsed := time.Since(start) + + require.Equal(t, FailoverCanceled, action) + require.Less(t, elapsed, 100*time.Millisecond, "应立即返回而非等待 1s") + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — FailedAccountIDs 跟踪 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_FailedAccountIDs(t *testing.T) { + t.Run("切换时添加到失败列表", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + + fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false)) + require.Contains(t, fs.FailedAccountIDs, int64(100)) + + fs.HandleFailoverError(context.Background(), mock, 200, "openai", newTestFailoverErr(502, false, false)) + require.Contains(t, fs.FailedAccountIDs, int64(200)) + require.Len(t, fs.FailedAccountIDs, 2) + }) + + t.Run("耗尽时也添加到失败列表", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(0, false) + + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false)) + require.Equal(t, FailoverExhausted, action) + require.Contains(t, fs.FailedAccountIDs, int64(100)) + }) + + t.Run("同账号重试期间不添加到失败列表", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + + action := fs.HandleFailoverError(context.Background(), 
mock, 100, "openai", newTestFailoverErr(400, true, false)) + require.Equal(t, FailoverContinue, action) + require.NotContains(t, fs.FailedAccountIDs, int64(100)) + }) + + t.Run("同一账号多次切换不重复添加", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(5, false) + + fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false)) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", newTestFailoverErr(500, false, false)) + require.Len(t, fs.FailedAccountIDs, 1, "map 天然去重") + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — LastFailoverErr 更新 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_LastFailoverErr(t *testing.T) { + t.Run("每次调用都更新LastFailoverErr", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + + err1 := newTestFailoverErr(500, false, false) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1) + require.Equal(t, err1, fs.LastFailoverErr) + + err2 := newTestFailoverErr(502, false, false) + fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2) + require.Equal(t, err2, fs.LastFailoverErr) + }) + + t.Run("同账号重试时也更新LastFailoverErr", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + + err := newTestFailoverErr(400, true, false) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, err, fs.LastFailoverErr) + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — 综合集成场景 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_IntegrationScenario(t *testing.T) { + t.Run("模拟完整failover流程_多账号混合重试与切换", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, true) // 
hasBoundSession=true + + // 1. 账号 100 遇到可重试错误,同账号重试 2 次 + retryErr := newTestFailoverErr(400, true, false) + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr) + require.Equal(t, FailoverContinue, action) + require.True(t, fs.ForceCacheBilling, "hasBoundSession=true 应设置 ForceCacheBilling") + + action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr) + require.Equal(t, FailoverContinue, action) + + // 2. 账号 100 重试耗尽 → TempUnschedule + 切换 + action = fs.HandleFailoverError(context.Background(), mock, 100, "openai", retryErr) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SwitchCount) + require.Len(t, mock.calls, 1) + + // 3. 账号 200 遇到不可重试错误 → 直接切换 + switchErr := newTestFailoverErr(500, false, false) + action = fs.HandleFailoverError(context.Background(), mock, 200, "openai", switchErr) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 2, fs.SwitchCount) + + // 4. 账号 300 遇到不可重试错误 → 再切换 + action = fs.HandleFailoverError(context.Background(), mock, 300, "openai", switchErr) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 3, fs.SwitchCount) + + // 5. 
账号 400 → 已耗尽 (SwitchCount=3 >= MaxSwitches=3) + action = fs.HandleFailoverError(context.Background(), mock, 400, "openai", switchErr) + require.Equal(t, FailoverExhausted, action) + + // 最终状态验证 + require.Equal(t, 3, fs.SwitchCount, "耗尽时不再递增") + require.Len(t, fs.FailedAccountIDs, 4, "4个不同账号都在失败列表中") + require.True(t, fs.ForceCacheBilling) + require.Len(t, mock.calls, 1, "只有账号 100 触发了 TempUnschedule") + }) + + t.Run("模拟Antigravity平台完整流程", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(2, false) + + err := newTestFailoverErr(500, false, false) + + // 第一次切换:delay = 0s + start := time.Now() + action := fs.HandleFailoverError(context.Background(), mock, 100, service.PlatformAntigravity, err) + elapsed := time.Since(start) + require.Equal(t, FailoverContinue, action) + require.Less(t, elapsed, 200*time.Millisecond, "第一次切换延迟为 0") + + // 第二次切换:delay = 1s + start = time.Now() + action = fs.HandleFailoverError(context.Background(), mock, 200, service.PlatformAntigravity, err) + elapsed = time.Since(start) + require.Equal(t, FailoverContinue, action) + require.GreaterOrEqual(t, elapsed, 800*time.Millisecond, "第二次切换延迟约 1s") + + // 第三次:耗尽(无延迟,因为在检查延迟之前就返回了) + start = time.Now() + action = fs.HandleFailoverError(context.Background(), mock, 300, service.PlatformAntigravity, err) + elapsed = time.Since(start) + require.Equal(t, FailoverExhausted, action) + require.Less(t, elapsed, 200*time.Millisecond, "耗尽时不应有延迟") + }) + + t.Run("ForceCacheBilling通过错误标志设置", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) // hasBoundSession=false + + // 第一次:ForceCacheBilling=false + err1 := newTestFailoverErr(500, false, false) + fs.HandleFailoverError(context.Background(), mock, 100, "openai", err1) + require.False(t, fs.ForceCacheBilling) + + // 第二次:ForceCacheBilling=true(Antigravity 粘性会话切换) + err2 := newTestFailoverErr(500, false, true) + fs.HandleFailoverError(context.Background(), mock, 200, "openai", err2) + 
require.True(t, fs.ForceCacheBilling, "错误标志应触发 ForceCacheBilling") + + // 第三次:ForceCacheBilling=false,但状态仍保持 true + err3 := newTestFailoverErr(500, false, false) + fs.HandleFailoverError(context.Background(), mock, 300, "openai", err3) + require.True(t, fs.ForceCacheBilling, "不应重置") + }) +} + +// --------------------------------------------------------------------------- +// HandleFailoverError — 边界条件 +// --------------------------------------------------------------------------- + +func TestHandleFailoverError_EdgeCases(t *testing.T) { + t.Run("StatusCode为0的错误也能正常处理", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(0, false, false) + + action := fs.HandleFailoverError(context.Background(), mock, 100, "openai", err) + require.Equal(t, FailoverContinue, action) + }) + + t.Run("AccountID为0也能正常跟踪", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, true, false) + + action := fs.HandleFailoverError(context.Background(), mock, 0, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SameAccountRetryCount[0]) + }) + + t.Run("负AccountID也能正常跟踪", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + err := newTestFailoverErr(500, true, false) + + action := fs.HandleFailoverError(context.Background(), mock, -1, "openai", err) + require.Equal(t, FailoverContinue, action) + require.Equal(t, 1, fs.SameAccountRetryCount[-1]) + }) + + t.Run("空平台名称不触发Antigravity延迟", func(t *testing.T) { + mock := &mockTempUnscheduler{} + fs := NewFailoverState(3, false) + fs.SwitchCount = 1 + err := newTestFailoverErr(500, false, false) + + start := time.Now() + action := fs.HandleFailoverError(context.Background(), mock, 100, "", err) + elapsed := time.Since(start) + + require.Equal(t, FailoverContinue, action) + require.Less(t, elapsed, 200*time.Millisecond, "空平台不应触发 Antigravity 延迟") + }) +} 
+ +// --------------------------------------------------------------------------- +// HandleSelectionExhausted 测试 +// --------------------------------------------------------------------------- + +func TestHandleSelectionExhausted(t *testing.T) { + t.Run("无LastFailoverErr时返回Exhausted", func(t *testing.T) { + fs := NewFailoverState(3, false) + // LastFailoverErr 为 nil + + action := fs.HandleSelectionExhausted(context.Background()) + require.Equal(t, FailoverExhausted, action) + }) + + t.Run("非503错误返回Exhausted", func(t *testing.T) { + fs := NewFailoverState(3, false) + fs.LastFailoverErr = newTestFailoverErr(500, false, false) + + action := fs.HandleSelectionExhausted(context.Background()) + require.Equal(t, FailoverExhausted, action) + }) + + t.Run("503且未耗尽_等待后返回Continue并清除失败列表", func(t *testing.T) { + fs := NewFailoverState(3, false) + fs.LastFailoverErr = newTestFailoverErr(503, false, false) + fs.FailedAccountIDs[100] = struct{}{} + fs.SwitchCount = 1 + + start := time.Now() + action := fs.HandleSelectionExhausted(context.Background()) + elapsed := time.Since(start) + + require.Equal(t, FailoverContinue, action) + require.Empty(t, fs.FailedAccountIDs, "应清除失败账号列表") + require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "应等待约 2s") + require.Less(t, elapsed, 5*time.Second) + }) + + t.Run("503但SwitchCount已超过MaxSwitches_返回Exhausted", func(t *testing.T) { + fs := NewFailoverState(2, false) + fs.LastFailoverErr = newTestFailoverErr(503, false, false) + fs.SwitchCount = 3 // > MaxSwitches(2) + + start := time.Now() + action := fs.HandleSelectionExhausted(context.Background()) + elapsed := time.Since(start) + + require.Equal(t, FailoverExhausted, action) + require.Less(t, elapsed, 100*time.Millisecond, "不应等待") + }) + + t.Run("503但context已取消_返回Canceled", func(t *testing.T) { + fs := NewFailoverState(3, false) + fs.LastFailoverErr = newTestFailoverErr(503, false, false) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + start := time.Now() + 
action := fs.HandleSelectionExhausted(ctx) + elapsed := time.Since(start) + + require.Equal(t, FailoverCanceled, action) + require.Less(t, elapsed, 100*time.Millisecond, "应立即返回") + }) + + t.Run("503且SwitchCount等于MaxSwitches_仍可重试", func(t *testing.T) { + fs := NewFailoverState(2, false) + fs.LastFailoverErr = newTestFailoverErr(503, false, false) + fs.SwitchCount = 2 // == MaxSwitches,条件是 <=,仍可重试 + + action := fs.HandleSelectionExhausted(context.Background()) + require.Equal(t, FailoverContinue, action) + }) +} diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index c2b6bf09..1c0ef8e6 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -6,10 +6,10 @@ import ( "encoding/json" "errors" "fmt" - "io" - "log" "net/http" + "strconv" "strings" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -18,14 +18,22 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" + "go.uber.org/zap" ) +const gatewayCompatibilityMetricsLogInterval = 1024 + +var gatewayCompatibilityMetricsLogCounter atomic.Uint64 + // GatewayHandler handles API gateway requests type GatewayHandler struct { gatewayService *service.GatewayService @@ -35,10 +43,14 @@ type GatewayHandler struct { billingCacheService *service.BillingCacheService usageService *service.UsageService apiKeyService *service.APIKeyService + usageRecordWorkerPool 
*service.UsageRecordWorkerPool errorPassthroughService *service.ErrorPassthroughService concurrencyHelper *ConcurrencyHelper + userMsgQueueHelper *UserMsgQueueHelper maxAccountSwitches int maxAccountSwitchesGemini int + cfg *config.Config + settingService *service.SettingService } // NewGatewayHandler creates a new GatewayHandler @@ -51,8 +63,11 @@ func NewGatewayHandler( billingCacheService *service.BillingCacheService, usageService *service.UsageService, apiKeyService *service.APIKeyService, + usageRecordWorkerPool *service.UsageRecordWorkerPool, errorPassthroughService *service.ErrorPassthroughService, + userMsgQueueService *service.UserMessageQueueService, cfg *config.Config, + settingService *service.SettingService, ) *GatewayHandler { pingInterval := time.Duration(0) maxAccountSwitches := 10 @@ -66,6 +81,13 @@ func NewGatewayHandler( maxAccountSwitchesGemini = cfg.Gateway.MaxAccountSwitchesGemini } } + + // 初始化用户消息串行队列 helper + var umqHelper *UserMsgQueueHelper + if userMsgQueueService != nil && cfg != nil { + umqHelper = NewUserMsgQueueHelper(userMsgQueueService, SSEPingFormatClaude, pingInterval) + } + return &GatewayHandler{ gatewayService: gatewayService, geminiCompatService: geminiCompatService, @@ -74,10 +96,14 @@ func NewGatewayHandler( billingCacheService: billingCacheService, usageService: usageService, apiKeyService: apiKeyService, + usageRecordWorkerPool: usageRecordWorkerPool, errorPassthroughService: errorPassthroughService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), + userMsgQueueHelper: umqHelper, maxAccountSwitches: maxAccountSwitches, maxAccountSwitchesGemini: maxAccountSwitchesGemini, + cfg: cfg, + settingService: settingService, } } @@ -96,9 +122,17 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") return } + reqLog := requestLogger( + c, + "handler.gateway.messages", + zap.Int64("user_id", 
subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + defer h.maybeLogCompatibilityFallbackMetrics(reqLog) // 读取请求体 - body, err := io.ReadAll(c.Request.Body) + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) if err != nil { if maxErr, ok := extractMaxBytesError(err); ok { h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) @@ -122,20 +156,26 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } reqModel := parsedReq.Model reqStream := parsedReq.Stream + reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) // 设置 max_tokens=1 + haiku 探测请求标识到 context 中 // 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断 if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) { - ctx := context.WithValue(c.Request.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true) + ctx := service.WithIsMaxTokensOneHaikuRequest(c.Request.Context(), true, h.metadataBridgeEnabled()) c.Request = c.Request.WithContext(ctx) } - // 检查是否为 Claude Code 客户端,设置到 context 中 - SetClaudeCodeClientContext(c, body) + // 检查是否为 Claude Code 客户端,设置到 context 中(复用已解析请求,避免二次反序列化)。 + SetClaudeCodeClientContext(c, body, parsedReq) isClaudeCodeClient := service.IsClaudeCodeClient(c.Request.Context()) + // 版本检查:仅对 Claude Code 客户端,拒绝低于最低版本的请求 + if !h.checkClaudeCodeVersion(c) { + return + } + // 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用 - c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled)) + c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled())) setOpsRequestContext(c, reqModel, reqStream, body) @@ -161,9 +201,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, 
maxWait) waitCounted := false if err != nil { - log.Printf("Increment wait count failed: %v", err) + reqLog.Warn("gateway.user_wait_counter_increment_failed", zap.Error(err)) // On error, allow request to proceed } else if !canWait { + reqLog.Info("gateway.user_wait_queue_full", zap.Int("max_wait", maxWait)) h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") return } @@ -180,7 +221,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 1. 首先获取用户并发槽位 userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) if err != nil { - log.Printf("User concurrency acquire failed: %v", err) + reqLog.Warn("gateway.user_slot_acquire_failed", zap.Error(err)) h.handleConcurrencyError(c, err, "user", streamStarted) return } @@ -197,7 +238,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 2. 【新增】Wait后二次检查余额/订阅 if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { - log.Printf("Billing eligibility check failed after wait: %v", err) + reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err)) status, code, message := billingErrorDetails(err) h.handleStreamingAwareError(c, status, code, message, streamStarted) return @@ -227,54 +268,54 @@ func (h *GatewayHandler) Messages(c *gin.Context) { var sessionBoundAccountID int64 if sessionKey != "" { sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey) + if sessionBoundAccountID > 0 { + prefetchedGroupID := int64(0) + if apiKey.GroupID != nil { + prefetchedGroupID = *apiKey.GroupID + } + ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled()) + c.Request = c.Request.WithContext(ctx) + } } // 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号 hasBoundSession := 
sessionKey != "" && sessionBoundAccountID > 0 if platform == service.PlatformGemini { - maxAccountSwitches := h.maxAccountSwitchesGemini - switchCount := 0 - failedAccountIDs := make(map[int64]struct{}) - sameAccountRetryCount := make(map[int64]int) // 同账号重试计数 - var lastFailoverErr *service.UpstreamFailoverError - var forceCacheBilling bool // 粘性会话切换时的缓存计费标记 + fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession) // 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。 // 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。 if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) { - ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) c.Request = c.Request.WithContext(ctx) } for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制 + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "") // Gemini 不使用会话限制 if err != nil { - if len(failedAccountIDs) == 0 { + if len(fs.FailedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } - // Antigravity 单账号退避重试:分组内没有其他可用账号时, - // 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。 - // 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。 - if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches { - if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) { - log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches) - failedAccountIDs = make(map[int64]struct{}) - // 设置 context 标记,让 Service 
层预检查等待限流过期而非直接切换 - ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) - c.Request = c.Request.WithContext(ctx) - continue + action := fs.HandleSelectionExhausted(c.Request.Context()) + switch action { + case FailoverContinue: + ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) + c.Request = c.Request.WithContext(ctx) + continue + case FailoverCanceled: + return + default: // FailoverExhausted + if fs.LastFailoverErr != nil { + h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted) + } else { + h.handleFailoverExhaustedSimple(c, 502, streamStarted) } + return } - if lastFailoverErr != nil { - h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted) - } else { - h.handleFailoverExhaustedSimple(c, 502, streamStarted) - } - return } account := selection.Account - setOpsSelectedAccount(c, account.ID) + setOpsSelectedAccount(c, account.ID, account.Platform) // 检查请求拦截(预热请求、SUGGESTION MODE等) if account.IsInterceptWarmupEnabled() { @@ -302,21 +343,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) { accountWaitCounted := false canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) if err != nil { - log.Printf("Increment account wait count failed: %v", err) + reqLog.Warn("gateway.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) + reqLog.Info("gateway.account_wait_queue_full", + zap.Int64("account_id", account.ID), + zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), + ) h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) return } if err == nil && canWait { accountWaitCounted = true } - // Ensure the wait counter is decremented if we exit before 
acquiring the slot. - defer func() { + releaseWait := func() { if accountWaitCounted { h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false } - }() + } accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( c, @@ -327,17 +371,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) { &streamStarted, ) if err != nil { - log.Printf("Account concurrency acquire failed: %v", err) + reqLog.Warn("gateway.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + releaseWait() h.handleConcurrencyError(c, err, "account", streamStarted) return } // Slot acquired: no longer waiting in queue. - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } + releaseWait() if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { - log.Printf("Bind sticky session failed: %v", err) + reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } } // 账号槽位/等待计数需要在超时或断开时安全回收 @@ -346,8 +388,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 转发请求 - 根据账号平台分流 var result *service.ForwardResult requestCtx := c.Request.Context() - if switchCount > 0 { - requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) + if fs.SwitchCount > 0 { + requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled()) } if account.Platform == service.PlatformAntigravity { result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession) @@ -360,68 +402,62 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { - lastFailoverErr = failoverErr - if needForceCacheBilling(hasBoundSession, 
failoverErr) { - forceCacheBilling = true - } - - // 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试 - if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries { - sameAccountRetryCount[account.ID]++ - log.Printf("Account %d: retryable error %d, same-account retry %d/%d", - account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries) - if !sleepSameAccountRetryDelay(c.Request.Context()) { - return - } + action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr) + switch action { + case FailoverContinue: continue - } - - // 同账号重试用尽,执行临时封禁并切换账号 - if failoverErr.RetryableOnSameAccount { - h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr) - } - - failedAccountIDs[account.ID] = struct{}{} - if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted) + case FailoverExhausted: + h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted) + return + case FailoverCanceled: return } - switchCount++ - log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) - if account.Platform == service.PlatformAntigravity { - if !sleepFailoverDelay(c.Request.Context(), switchCount) { - return - } - } - continue } - // 错误响应已在Forward中处理,这里只记录日志 - log.Printf("Forward request failed: %v", err) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + reqLog.Error("gateway.forward_failed", + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + ) return } + // RPM 计数递增(Forward 成功后) + // 注意:TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。 + // 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。 + if account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 { + if err := 
h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID); err != nil { + reqLog.Warn("gateway.rpm_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } + } + // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) - // 异步记录使用量(subscription已在函数开头获取) - go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 + h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, APIKey: apiKey, User: apiKey.User, - Account: usedAccount, + Account: account, Subscription: subscription, - UserAgent: ua, + UserAgent: userAgent, IPAddress: clientIP, - ForceCacheBilling: fcb, + ForceCacheBilling: fs.ForceCacheBilling, APIKeyService: h.apiKeyService, }); err != nil { - log.Printf("Record usage failed: %v", err) + logger.L().With( + zap.String("component", "handler.gateway.messages"), + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.String("model", reqModel), + zap.Int64("account_id", account.ID), + ).Error("gateway.record_usage_failed", zap.Error(err)) } - }(result, account, userAgent, clientIP, forceCacheBilling) + }) return } } @@ -437,49 +473,41 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。 // 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。 if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), currentAPIKey.GroupID) { - ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) c.Request = c.Request.WithContext(ctx) } for { - 
maxAccountSwitches := h.maxAccountSwitches - switchCount := 0 - failedAccountIDs := make(map[int64]struct{}) - sameAccountRetryCount := make(map[int64]int) // 同账号重试计数 - var lastFailoverErr *service.UpstreamFailoverError + fs := NewFailoverState(h.maxAccountSwitches, hasBoundSession) retryWithFallback := false - var forceCacheBilling bool // 粘性会话切换时的缓存计费标记 for { // 选择支持该模型的账号 - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID) + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID) if err != nil { - if len(failedAccountIDs) == 0 { + if len(fs.FailedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } - // Antigravity 单账号退避重试:分组内没有其他可用账号时, - // 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。 - // 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。 - if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches { - if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) { - log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches) - failedAccountIDs = make(map[int64]struct{}) - // 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换 - ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) - c.Request = c.Request.WithContext(ctx) - continue + action := fs.HandleSelectionExhausted(c.Request.Context()) + switch action { + case FailoverContinue: + ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) + c.Request = c.Request.WithContext(ctx) + continue + case FailoverCanceled: + return + default: // FailoverExhausted + if fs.LastFailoverErr 
!= nil { + h.handleFailoverExhausted(c, fs.LastFailoverErr, platform, streamStarted) + } else { + h.handleFailoverExhaustedSimple(c, 502, streamStarted) } + return } - if lastFailoverErr != nil { - h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted) - } else { - h.handleFailoverExhaustedSimple(c, 502, streamStarted) - } - return } account := selection.Account - setOpsSelectedAccount(c, account.ID) + setOpsSelectedAccount(c, account.ID, account.Platform) // 检查请求拦截(预热请求、SUGGESTION MODE等) if account.IsInterceptWarmupEnabled() { @@ -507,20 +535,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) { accountWaitCounted := false canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) if err != nil { - log.Printf("Increment account wait count failed: %v", err) + reqLog.Warn("gateway.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) + reqLog.Info("gateway.account_wait_queue_full", + zap.Int64("account_id", account.ID), + zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), + ) h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) return } if err == nil && canWait { accountWaitCounted = true } - defer func() { + releaseWait := func() { if accountWaitCounted { h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false } - }() + } accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( c, @@ -531,50 +563,117 @@ func (h *GatewayHandler) Messages(c *gin.Context) { &streamStarted, ) if err != nil { - log.Printf("Account concurrency acquire failed: %v", err) + reqLog.Warn("gateway.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + releaseWait() 
h.handleConcurrencyError(c, err, "account", streamStarted) return } - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } + // Slot acquired: no longer waiting in queue. + releaseWait() if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil { - log.Printf("Bind sticky session failed: %v", err) + reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } } // 账号槽位/等待计数需要在超时或断开时安全回收 accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + // ===== 用户消息串行队列 START ===== + var queueRelease func() + umqMode := h.getUserMsgQueueMode(account, parsedReq) + + switch umqMode { + case config.UMQModeSerialize: + // 串行模式:获取锁 + RPM 延迟 + 释放(当前行为不变) + baseRPM := account.GetBaseRPM() + release, qErr := h.userMsgQueueHelper.AcquireWithWait( + c, account.ID, baseRPM, reqStream, &streamStarted, + h.cfg.Gateway.UserMessageQueue.WaitTimeout(), + reqLog, + ) + if qErr != nil { + // fail-open: 记录 warn,不阻止请求 + reqLog.Warn("gateway.umq_acquire_failed", + zap.Int64("account_id", account.ID), + zap.Error(qErr), + ) + } else { + queueRelease = release + } + + case config.UMQModeThrottle: + // 软性限速:仅施加 RPM 自适应延迟,不阻塞并发 + baseRPM := account.GetBaseRPM() + if tErr := h.userMsgQueueHelper.ThrottleWithPing( + c, account.ID, baseRPM, reqStream, &streamStarted, + h.cfg.Gateway.UserMessageQueue.WaitTimeout(), + reqLog, + ); tErr != nil { + reqLog.Warn("gateway.umq_throttle_failed", + zap.Int64("account_id", account.ID), + zap.Error(tErr), + ) + } + + default: + if umqMode != "" { + reqLog.Warn("gateway.umq_unknown_mode", + zap.String("mode", umqMode), + zap.Int64("account_id", account.ID), + ) + } + } + + // 用 wrapReleaseOnDone 确保 context 取消时自动释放(仅 serialize 模式有 queueRelease) + queueRelease = wrapReleaseOnDone(c.Request.Context(), queueRelease) + // 注入回调到 ParsedRequest:使用外层 
wrapper 以便提前清理 AfterFunc + parsedReq.OnUpstreamAccepted = queueRelease + // ===== 用户消息串行队列 END ===== + // 转发请求 - 根据账号平台分流 var result *service.ForwardResult requestCtx := c.Request.Context() - if switchCount > 0 { - requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) + if fs.SwitchCount > 0 { + requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled()) } if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession) } else { result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq) } + + // 兜底释放串行锁(正常情况已通过回调提前释放) + if queueRelease != nil { + queueRelease() + } + // 清理回调引用,防止 failover 重试时旧回调被错误调用 + parsedReq.OnUpstreamAccepted = nil + if accountReleaseFunc != nil { accountReleaseFunc() } if err != nil { var promptTooLongErr *service.PromptTooLongError if errors.As(err, &promptTooLongErr) { - log.Printf("Prompt too long from antigravity: group=%d fallback_group_id=%v fallback_used=%v", currentAPIKey.GroupID, fallbackGroupID, fallbackUsed) + reqLog.Warn("gateway.prompt_too_long_from_antigravity", + zap.Any("current_group_id", currentAPIKey.GroupID), + zap.Any("fallback_group_id", fallbackGroupID), + zap.Bool("fallback_used", fallbackUsed), + ) if !fallbackUsed && fallbackGroupID != nil && *fallbackGroupID > 0 { fallbackGroup, err := h.gatewayService.ResolveGroupByID(c.Request.Context(), *fallbackGroupID) if err != nil { - log.Printf("Resolve fallback group failed: %v", err) + reqLog.Warn("gateway.resolve_fallback_group_failed", zap.Int64("fallback_group_id", *fallbackGroupID), zap.Error(err)) _ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body) return } if fallbackGroup.Platform != service.PlatformAnthropic || fallbackGroup.SubscriptionType == 
service.SubscriptionTypeSubscription || fallbackGroup.FallbackGroupIDOnInvalidRequest != nil { - log.Printf("Fallback group invalid: group=%d platform=%s subscription=%s", fallbackGroup.ID, fallbackGroup.Platform, fallbackGroup.SubscriptionType) + reqLog.Warn("gateway.fallback_group_invalid", + zap.Int64("fallback_group_id", fallbackGroup.ID), + zap.String("fallback_platform", fallbackGroup.Platform), + zap.String("fallback_subscription_type", fallbackGroup.SubscriptionType), + ) _ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body) return } @@ -584,7 +683,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleStreamingAwareError(c, status, code, message, streamStarted) return } - // 兜底重试按“直接请求兜底分组”处理:清除强制平台,允许按分组平台调度 + // 兜底重试按"直接请求兜底分组"处理:清除强制平台,允许按分组平台调度 ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, "") c.Request = c.Request.WithContext(ctx) currentAPIKey = fallbackAPIKey @@ -598,68 +697,62 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { - lastFailoverErr = failoverErr - if needForceCacheBilling(hasBoundSession, failoverErr) { - forceCacheBilling = true - } - - // 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试 - if failoverErr.RetryableOnSameAccount && sameAccountRetryCount[account.ID] < maxSameAccountRetries { - sameAccountRetryCount[account.ID]++ - log.Printf("Account %d: retryable error %d, same-account retry %d/%d", - account.ID, failoverErr.StatusCode, sameAccountRetryCount[account.ID], maxSameAccountRetries) - if !sleepSameAccountRetryDelay(c.Request.Context()) { - return - } + action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr) + switch action { + case FailoverContinue: continue - } - - // 同账号重试用尽,执行临时封禁并切换账号 - if failoverErr.RetryableOnSameAccount { - 
h.gatewayService.TempUnscheduleRetryableError(c.Request.Context(), account.ID, failoverErr) - } - - failedAccountIDs[account.ID] = struct{}{} - if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted) + case FailoverExhausted: + h.handleFailoverExhausted(c, fs.LastFailoverErr, account.Platform, streamStarted) + return + case FailoverCanceled: return } - switchCount++ - log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) - if account.Platform == service.PlatformAntigravity { - if !sleepFailoverDelay(c.Request.Context(), switchCount) { - return - } - } - continue } - // 错误响应已在Forward中处理,这里只记录日志 - log.Printf("Account %d: Forward request failed: %v", account.ID, err) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + reqLog.Error("gateway.forward_failed", + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + ) return } + // RPM 计数递增(Forward 成功后) + // 注意:TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。 + // 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。 + if account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 { + if err := h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID); err != nil { + reqLog.Warn("gateway.rpm_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } + } + // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) - // 异步记录使用量(subscription已在函数开头获取) - go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 + h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, 
APIKey: currentAPIKey, User: currentAPIKey.User, - Account: usedAccount, + Account: account, Subscription: currentSubscription, - UserAgent: ua, + UserAgent: userAgent, IPAddress: clientIP, - ForceCacheBilling: fcb, + ForceCacheBilling: fs.ForceCacheBilling, APIKeyService: h.apiKeyService, }); err != nil { - log.Printf("Record usage failed: %v", err) + logger.L().With( + zap.String("component", "handler.gateway.messages"), + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", currentAPIKey.ID), + zap.Any("group_id", currentAPIKey.GroupID), + zap.String("model", reqModel), + zap.Int64("account_id", account.ID), + ).Error("gateway.record_usage_failed", zap.Error(err)) } - }(result, account, userAgent, clientIP, forceCacheBilling) + }) return } if !retryWithFallback { @@ -682,6 +775,17 @@ func (h *GatewayHandler) Models(c *gin.Context) { groupID = &apiKey.Group.ID platform = apiKey.Group.Platform } + if forcedPlatform, ok := middleware2.GetForcePlatformFromContext(c); ok && strings.TrimSpace(forcedPlatform) != "" { + platform = forcedPlatform + } + + if platform == service.PlatformSora { + c.JSON(http.StatusOK, gin.H{ + "object": "list", + "data": service.DefaultSoraModels(h.cfg), + }) + return + } // Get available models from account configurations (without platform filter) availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "") @@ -741,6 +845,10 @@ func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service // Usage handles getting account balance and usage statistics for CC Switch integration // GET /v1/usage +// +// Two modes: +// - quota_limited: API Key has quota or rate limits configured. Returns key-level limits/usage. +// - unrestricted: No key-level limits. Returns subscription or wallet balance info. 
func (h *GatewayHandler) Usage(c *gin.Context) { apiKey, ok := middleware2.GetAPIKeyFromContext(c) if !ok { @@ -754,54 +862,183 @@ func (h *GatewayHandler) Usage(c *gin.Context) { return } + ctx := c.Request.Context() + + // 解析可选的日期范围参数(用于 model_stats 查询) + startTime, endTime := h.parseUsageDateRange(c) + // Best-effort: 获取用量统计(按当前 API Key 过滤),失败不影响基础响应 - var usageData gin.H + usageData := h.buildUsageData(ctx, apiKey.ID) + + // Best-effort: 获取模型统计 + var modelStats any if h.usageService != nil { - dashStats, err := h.usageService.GetAPIKeyDashboardStats(c.Request.Context(), apiKey.ID) - if err == nil && dashStats != nil { - usageData = gin.H{ - "today": gin.H{ - "requests": dashStats.TodayRequests, - "input_tokens": dashStats.TodayInputTokens, - "output_tokens": dashStats.TodayOutputTokens, - "cache_creation_tokens": dashStats.TodayCacheCreationTokens, - "cache_read_tokens": dashStats.TodayCacheReadTokens, - "total_tokens": dashStats.TodayTokens, - "cost": dashStats.TodayCost, - "actual_cost": dashStats.TodayActualCost, - }, - "total": gin.H{ - "requests": dashStats.TotalRequests, - "input_tokens": dashStats.TotalInputTokens, - "output_tokens": dashStats.TotalOutputTokens, - "cache_creation_tokens": dashStats.TotalCacheCreationTokens, - "cache_read_tokens": dashStats.TotalCacheReadTokens, - "total_tokens": dashStats.TotalTokens, - "cost": dashStats.TotalCost, - "actual_cost": dashStats.TotalActualCost, - }, - "average_duration_ms": dashStats.AverageDurationMs, - "rpm": dashStats.Rpm, - "tpm": dashStats.Tpm, + if stats, err := h.usageService.GetAPIKeyModelStats(ctx, apiKey.ID, startTime, endTime); err == nil && len(stats) > 0 { + modelStats = stats + } + } + + // 判断模式: key 有总额度或速率限制 → quota_limited,否则 → unrestricted + isQuotaLimited := apiKey.Quota > 0 || apiKey.HasRateLimits() + + if isQuotaLimited { + h.usageQuotaLimited(c, ctx, apiKey, usageData, modelStats) + return + } + + h.usageUnrestricted(c, ctx, apiKey, subject, usageData, modelStats) +} + +// 
parseUsageDateRange 解析 start_date / end_date query params,默认返回近 30 天范围 +func (h *GatewayHandler) parseUsageDateRange(c *gin.Context) (time.Time, time.Time) { + now := timezone.Now() + endTime := now + startTime := now.AddDate(0, 0, -30) + + if s := c.Query("start_date"); s != "" { + if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil { + startTime = t + } + } + if s := c.Query("end_date"); s != "" { + if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil { + endTime = t.Add(24*time.Hour - time.Second) // end of day + } + } + return startTime, endTime +} + +// buildUsageData 构建 today/total 用量摘要 +func (h *GatewayHandler) buildUsageData(ctx context.Context, apiKeyID int64) gin.H { + if h.usageService == nil { + return nil + } + dashStats, err := h.usageService.GetAPIKeyDashboardStats(ctx, apiKeyID) + if err != nil || dashStats == nil { + return nil + } + return gin.H{ + "today": gin.H{ + "requests": dashStats.TodayRequests, + "input_tokens": dashStats.TodayInputTokens, + "output_tokens": dashStats.TodayOutputTokens, + "cache_creation_tokens": dashStats.TodayCacheCreationTokens, + "cache_read_tokens": dashStats.TodayCacheReadTokens, + "total_tokens": dashStats.TodayTokens, + "cost": dashStats.TodayCost, + "actual_cost": dashStats.TodayActualCost, + }, + "total": gin.H{ + "requests": dashStats.TotalRequests, + "input_tokens": dashStats.TotalInputTokens, + "output_tokens": dashStats.TotalOutputTokens, + "cache_creation_tokens": dashStats.TotalCacheCreationTokens, + "cache_read_tokens": dashStats.TotalCacheReadTokens, + "total_tokens": dashStats.TotalTokens, + "cost": dashStats.TotalCost, + "actual_cost": dashStats.TotalActualCost, + }, + "average_duration_ms": dashStats.AverageDurationMs, + "rpm": dashStats.Rpm, + "tpm": dashStats.Tpm, + } +} + +// usageQuotaLimited 处理 quota_limited 模式的响应 +func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context, apiKey *service.APIKey, usageData gin.H, modelStats any) { + resp := gin.H{ 
+ "mode": "quota_limited", + "isValid": apiKey.Status == service.StatusAPIKeyActive || apiKey.Status == service.StatusAPIKeyQuotaExhausted || apiKey.Status == service.StatusAPIKeyExpired, + "status": apiKey.Status, + } + + // 总额度信息 + if apiKey.Quota > 0 { + remaining := apiKey.GetQuotaRemaining() + resp["quota"] = gin.H{ + "limit": apiKey.Quota, + "used": apiKey.QuotaUsed, + "remaining": remaining, + "unit": "USD", + } + resp["remaining"] = remaining + resp["unit"] = "USD" + } + + // 速率限制信息(从 DB 获取实时用量) + if apiKey.HasRateLimits() && h.apiKeyService != nil { + rateLimitData, err := h.apiKeyService.GetRateLimitData(ctx, apiKey.ID) + if err == nil && rateLimitData != nil { + var rateLimits []gin.H + if apiKey.RateLimit5h > 0 { + used := rateLimitData.Usage5h + rateLimits = append(rateLimits, gin.H{ + "window": "5h", + "limit": apiKey.RateLimit5h, + "used": used, + "remaining": max(0, apiKey.RateLimit5h-used), + "window_start": rateLimitData.Window5hStart, + }) + } + if apiKey.RateLimit1d > 0 { + used := rateLimitData.Usage1d + rateLimits = append(rateLimits, gin.H{ + "window": "1d", + "limit": apiKey.RateLimit1d, + "used": used, + "remaining": max(0, apiKey.RateLimit1d-used), + "window_start": rateLimitData.Window1dStart, + }) + } + if apiKey.RateLimit7d > 0 { + used := rateLimitData.Usage7d + rateLimits = append(rateLimits, gin.H{ + "window": "7d", + "limit": apiKey.RateLimit7d, + "used": used, + "remaining": max(0, apiKey.RateLimit7d-used), + "window_start": rateLimitData.Window7dStart, + }) + } + if len(rateLimits) > 0 { + resp["rate_limits"] = rateLimits } } } - // 订阅模式:返回订阅限额信息 + 用量统计 + // 过期时间 + if apiKey.ExpiresAt != nil { + resp["expires_at"] = apiKey.ExpiresAt + resp["days_until_expiry"] = apiKey.GetDaysUntilExpiry() + } + + if usageData != nil { + resp["usage"] = usageData + } + if modelStats != nil { + resp["model_stats"] = modelStats + } + + c.JSON(http.StatusOK, resp) +} + +// usageUnrestricted 处理 unrestricted 模式的响应(向后兼容) +func (h *GatewayHandler) 
usageUnrestricted(c *gin.Context, ctx context.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, usageData gin.H, modelStats any) { + // 订阅模式 if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() { - subscription, ok := middleware2.GetSubscriptionFromContext(c) - if !ok { - h.errorResponse(c, http.StatusForbidden, "subscription_error", "No active subscription") - return + resp := gin.H{ + "mode": "unrestricted", + "isValid": true, + "planName": apiKey.Group.Name, + "unit": "USD", } - remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription) - resp := gin.H{ - "isValid": true, - "planName": apiKey.Group.Name, - "remaining": remaining, - "unit": "USD", - "subscription": gin.H{ + // 订阅信息可能不在 context 中(/v1/usage 路径跳过了中间件的计费检查) + subscription, ok := middleware2.GetSubscriptionFromContext(c) + if ok { + remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription) + resp["remaining"] = remaining + resp["subscription"] = gin.H{ "daily_usage_usd": subscription.DailyUsageUSD, "weekly_usage_usd": subscription.WeeklyUsageUSD, "monthly_usage_usd": subscription.MonthlyUsageUSD, @@ -809,23 +1046,28 @@ func (h *GatewayHandler) Usage(c *gin.Context) { "weekly_limit_usd": apiKey.Group.WeeklyLimitUSD, "monthly_limit_usd": apiKey.Group.MonthlyLimitUSD, "expires_at": subscription.ExpiresAt, - }, + } } + if usageData != nil { resp["usage"] = usageData } + if modelStats != nil { + resp["model_stats"] = modelStats + } c.JSON(http.StatusOK, resp) return } - // 余额模式:返回钱包余额 + 用量统计 - latestUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID) + // 余额模式 + latestUser, err := h.userService.GetByID(ctx, subject.UserID) if err != nil { h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info") return } resp := gin.H{ + "mode": "unrestricted", "isValid": true, "planName": "钱包余额", "remaining": latestUser.Balance, @@ -835,6 +1077,9 @@ func (h *GatewayHandler) Usage(c *gin.Context) { if usageData != nil 
{ resp["usage"] = usageData } + if modelStats != nil { + resp["model_stats"] = modelStats + } c.JSON(http.StatusOK, resp) } @@ -893,65 +1138,6 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) } -// needForceCacheBilling 判断 failover 时是否需要强制缓存计费 -// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费 -func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool { - return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling) -} - -const ( - // maxSameAccountRetries 同账号重试次数上限(针对 RetryableOnSameAccount 错误) - maxSameAccountRetries = 2 - // sameAccountRetryDelay 同账号重试间隔 - sameAccountRetryDelay = 500 * time.Millisecond -) - -// sleepSameAccountRetryDelay 同账号重试固定延时,返回 false 表示 context 已取消。 -func sleepSameAccountRetryDelay(ctx context.Context) bool { - select { - case <-ctx.Done(): - return false - case <-time.After(sameAccountRetryDelay): - return true - } -} - -// sleepFailoverDelay 账号切换线性递增延时:第1次0s、第2次1s、第3次2s… -// 返回 false 表示 context 已取消。 -func sleepFailoverDelay(ctx context.Context, switchCount int) bool { - delay := time.Duration(switchCount-1) * time.Second - if delay <= 0 { - return true - } - select { - case <-ctx.Done(): - return false - case <-time.After(delay): - return true - } -} - -// sleepAntigravitySingleAccountBackoff Antigravity 平台单账号分组的 503 退避重试延时。 -// 当分组内只有一个可用账号且上游返回 503(MODEL_CAPACITY_EXHAUSTED)时使用, -// 采用短固定延时策略。Service 层在 SingleAccountRetry 模式下已经做了充分的原地重试 -// (最多 3 次、总等待 30s),所以 Handler 层的退避只需短暂等待即可。 -// 返回 false 表示 context 已取消。 -func sleepAntigravitySingleAccountBackoff(ctx context.Context, retryCount int) bool { - // 固定短延时:2s - // Service 层已经在原地等待了足够长的时间(retryDelay × 重试次数), - // Handler 层只需短暂间隔后重新进入 Service 层即可。 - const delay = 2 * time.Second - - log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)", delay, retryCount) - - select 
{ - case <-ctx.Done(): - return false - case <-time.After(delay): - return true - } -} - func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) { statusCode := failoverErr.StatusCode responseBody := failoverErr.ResponseBody @@ -1014,20 +1200,8 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e // Stream already started, send error as SSE event then close flusher, ok := c.Writer.(http.Flusher) if ok { - // Send error event in SSE format with proper JSON marshaling - errorData := map[string]any{ - "type": "error", - "error": map[string]string{ - "type": errType, - "message": message, - }, - } - jsonBytes, err := json.Marshal(errorData) - if err != nil { - _ = c.Error(err) - return - } - errorEvent := fmt.Sprintf("data: %s\n\n", string(jsonBytes)) + // SSE 错误事件固定 schema,使用 Quote 直拼可避免额外 Marshal 分配。 + errorEvent := `data: {"type":"error","error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n" if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil { _ = c.Error(err) } @@ -1040,6 +1214,50 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e h.errorResponse(c, status, errType, message) } +// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。 +func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool { + if c == nil || c.Writer == nil || c.Writer.Written() { + return false + } + h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted) + return true +} + +// checkClaudeCodeVersion 检查 Claude Code 客户端版本是否满足最低要求 +// 仅对已识别的 Claude Code 客户端执行,count_tokens 路径除外 +func (h *GatewayHandler) checkClaudeCodeVersion(c *gin.Context) bool { + ctx := c.Request.Context() + if !service.IsClaudeCodeClient(ctx) { + return true + } + + // 排除 count_tokens 子路径 + if 
strings.HasSuffix(c.Request.URL.Path, "/count_tokens") { + return true + } + + minVersion := h.settingService.GetMinClaudeCodeVersion(ctx) + if minVersion == "" { + return true // 未设置,不检查 + } + + clientVersion := service.GetClaudeCodeVersion(ctx) + if clientVersion == "" { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", + "Unable to determine Claude Code version. Please update Claude Code: npm update -g @anthropic-ai/claude-code") + return false + } + + if service.CompareVersions(clientVersion, minVersion) < 0 { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", + fmt.Sprintf("Your Claude Code version (%s) is below the minimum required version (%s). Please update: npm update -g @anthropic-ai/claude-code", + clientVersion, minVersion)) + return false + } + + return true +} + // errorResponse 返回Claude API格式的错误响应 func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) { c.JSON(status, gin.H{ @@ -1067,9 +1285,16 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") return } + reqLog := requestLogger( + c, + "handler.gateway.count_tokens", + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + defer h.maybeLogCompatibilityFallbackMetrics(reqLog) // 读取请求体 - body, err := io.ReadAll(c.Request.Body) + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) if err != nil { if maxErr, ok := extractMaxBytesError(err); ok { h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) @@ -1084,9 +1309,6 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { return } - // 检查是否为 Claude Code 客户端,设置到 context 中 - SetClaudeCodeClientContext(c, body) - setOpsRequestContext(c, "", false, body) parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic) @@ -1094,8 +1316,11 @@ func (h *GatewayHandler) 
CountTokens(c *gin.Context) { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return } + // count_tokens 走 messages 严格校验时,复用已解析请求,避免二次反序列化。 + SetClaudeCodeClientContext(c, body, parsedReq) + reqLog = reqLog.With(zap.String("model", parsedReq.Model), zap.Bool("stream", parsedReq.Stream)) // 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用 - c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled)) + c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled())) // 验证 model 必填 if parsedReq.Model == "" { @@ -1127,14 +1352,15 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { // 选择支持该模型的账号 account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model) if err != nil { - h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) + reqLog.Warn("gateway.count_tokens_select_account_failed", zap.Error(err)) + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable") return } - setOpsSelectedAccount(c, account.ID) + setOpsSelectedAccount(c, account.ID, account.Platform) // 转发请求(不记录使用量) if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil { - log.Printf("Forward count_tokens request failed: %v", err) + reqLog.Error("gateway.count_tokens_forward_failed", zap.Int64("account_id", account.ID), zap.Error(err)) // 错误响应已在 ForwardCountTokens 中处理 return } @@ -1258,24 +1484,8 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce textDeltas = []string{"New", " Conversation"} } - // Build message_start event with proper JSON marshaling - messageStart := map[string]any{ - "type": "message_start", - "message": map[string]any{ - "id": msgID, - "type": "message", - 
"role": "assistant", - "model": model, - "content": []any{}, - "stop_reason": nil, - "stop_sequence": nil, - "usage": map[string]int{ - "input_tokens": 10, - "output_tokens": 0, - }, - }, - } - messageStartJSON, _ := json.Marshal(messageStart) + // Build message_start event with fixed schema. + messageStartJSON := `{"type":"message_start","message":{"id":` + strconv.Quote(msgID) + `,"type":"message","role":"assistant","model":` + strconv.Quote(model) + `,"content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":0}}}` // Build events events := []string{ @@ -1285,31 +1495,12 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce // Add text deltas for _, text := range textDeltas { - delta := map[string]any{ - "type": "content_block_delta", - "index": 0, - "delta": map[string]string{ - "type": "text_delta", - "text": text, - }, - } - deltaJSON, _ := json.Marshal(delta) + deltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":` + strconv.Quote(text) + `}}` events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON)) } // Add final events - messageDelta := map[string]any{ - "type": "message_delta", - "delta": map[string]any{ - "stop_reason": "end_turn", - "stop_sequence": nil, - }, - "usage": map[string]int{ - "input_tokens": 10, - "output_tokens": outputTokens, - }, - } - messageDeltaJSON, _ := json.Marshal(messageDelta) + messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":10,"output_tokens":` + strconv.Itoa(outputTokens) + `}}` events = append(events, `event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`, @@ -1396,9 +1587,92 @@ func billingErrorDetails(err error) (status int, code, message string) { } return http.StatusServiceUnavailable, "billing_service_error", msg } + if errors.Is(err, service.ErrAPIKeyRateLimit5hExceeded) { + msg := 
pkgerrors.Message(err) + return http.StatusTooManyRequests, "rate_limit_exceeded", msg + } + if errors.Is(err, service.ErrAPIKeyRateLimit1dExceeded) { + msg := pkgerrors.Message(err) + return http.StatusTooManyRequests, "rate_limit_exceeded", msg + } + if errors.Is(err, service.ErrAPIKeyRateLimit7dExceeded) { + msg := pkgerrors.Message(err) + return http.StatusTooManyRequests, "rate_limit_exceeded", msg + } msg := pkgerrors.Message(err) if msg == "" { - msg = err.Error() + logger.L().With( + zap.String("component", "handler.gateway.billing"), + zap.Error(err), + ).Warn("gateway.billing_error_missing_message") + msg = "Billing error" } return http.StatusForbidden, "billing_error", msg } + +func (h *GatewayHandler) metadataBridgeEnabled() bool { + if h == nil || h.cfg == nil { + return true + } + return h.cfg.Gateway.OpenAIWS.MetadataBridgeEnabled +} + +func (h *GatewayHandler) maybeLogCompatibilityFallbackMetrics(reqLog *zap.Logger) { + if reqLog == nil { + return + } + if gatewayCompatibilityMetricsLogCounter.Add(1)%gatewayCompatibilityMetricsLogInterval != 0 { + return + } + metrics := service.SnapshotOpenAICompatibilityFallbackMetrics() + reqLog.Info("gateway.compatibility_fallback_metrics", + zap.Int64("session_hash_legacy_read_fallback_total", metrics.SessionHashLegacyReadFallbackTotal), + zap.Int64("session_hash_legacy_read_fallback_hit", metrics.SessionHashLegacyReadFallbackHit), + zap.Int64("session_hash_legacy_dual_write_total", metrics.SessionHashLegacyDualWriteTotal), + zap.Float64("session_hash_legacy_read_hit_rate", metrics.SessionHashLegacyReadHitRate), + zap.Int64("metadata_legacy_fallback_total", metrics.MetadataLegacyFallbackTotal), + ) +} + +func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { + if task == nil { + return + } + if h.usageRecordWorkerPool != nil { + h.usageRecordWorkerPool.Submit(task) + return + } + // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。 + ctx, cancel := context.WithTimeout(context.Background(), 
10*time.Second) + defer cancel() + defer func() { + if recovered := recover(); recovered != nil { + logger.L().With( + zap.String("component", "handler.gateway.messages"), + zap.Any("panic", recovered), + ).Error("gateway.usage_record_task_panic_recovered") + } + }() + task(ctx) +} + +// getUserMsgQueueMode 获取当前请求的 UMQ 模式 +// 返回 "serialize" | "throttle" | "" +func (h *GatewayHandler) getUserMsgQueueMode(account *service.Account, parsed *service.ParsedRequest) string { + if h.userMsgQueueHelper == nil { + return "" + } + // 仅适用于 Anthropic OAuth/SetupToken 账号 + if !account.IsAnthropicOAuthOrSetupToken() { + return "" + } + if !service.IsRealUserMessage(parsed) { + return "" + } + // 账号级模式优先,fallback 到全局配置 + mode := account.GetUserMsgQueueMode() + if mode == "" { + mode = h.cfg.Gateway.UserMessageQueue.GetEffectiveMode() + } + return mode +} diff --git a/backend/internal/handler/gateway_handler_error_fallback_test.go b/backend/internal/handler/gateway_handler_error_fallback_test.go new file mode 100644 index 00000000..4fce5ec1 --- /dev/null +++ b/backend/internal/handler/gateway_handler_error_fallback_test.go @@ -0,0 +1,49 @@ +package handler + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGatewayEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + h := &GatewayHandler{} + wrote := h.ensureForwardErrorResponse(c, false) + + require.True(t, wrote) + require.Equal(t, http.StatusBadGateway, w.Code) + + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + assert.Equal(t, "error", parsed["type"]) + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, 
"upstream_error", errorObj["type"]) + assert.Equal(t, "Upstream request failed", errorObj["message"]) +} + +func TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.String(http.StatusTeapot, "already written") + + h := &GatewayHandler{} + wrote := h.ensureForwardErrorResponse(c, false) + + require.False(t, wrote) + require.Equal(t, http.StatusTeapot, w.Code) + assert.Equal(t, "already written", w.Body.String()) +} diff --git a/backend/internal/handler/gateway_handler_single_account_retry_test.go b/backend/internal/handler/gateway_handler_single_account_retry_test.go deleted file mode 100644 index 96aa14c6..00000000 --- a/backend/internal/handler/gateway_handler_single_account_retry_test.go +++ /dev/null @@ -1,51 +0,0 @@ -package handler - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -// --------------------------------------------------------------------------- -// sleepAntigravitySingleAccountBackoff 测试 -// --------------------------------------------------------------------------- - -func TestSleepAntigravitySingleAccountBackoff_ReturnsTrue(t *testing.T) { - ctx := context.Background() - start := time.Now() - ok := sleepAntigravitySingleAccountBackoff(ctx, 1) - elapsed := time.Since(start) - - require.True(t, ok, "should return true when context is not canceled") - // 固定延迟 2s - require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond, "should wait approximately 2s") - require.Less(t, elapsed, 5*time.Second, "should not wait too long") -} - -func TestSleepAntigravitySingleAccountBackoff_ContextCanceled(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() // 立即取消 - - start := time.Now() - ok := sleepAntigravitySingleAccountBackoff(ctx, 1) - elapsed := time.Since(start) - - require.False(t, ok, "should 
return false when context is canceled") - require.Less(t, elapsed, 500*time.Millisecond, "should return immediately on cancel") -} - -func TestSleepAntigravitySingleAccountBackoff_FixedDelay(t *testing.T) { - // 验证不同 retryCount 都使用固定 2s 延迟 - ctx := context.Background() - - start := time.Now() - ok := sleepAntigravitySingleAccountBackoff(ctx, 5) - elapsed := time.Since(start) - - require.True(t, ok) - // 即使 retryCount=5,延迟仍然是固定的 2s - require.GreaterOrEqual(t, elapsed, 1500*time.Millisecond) - require.Less(t, elapsed, 5*time.Second) -} diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go new file mode 100644 index 00000000..c07c568d --- /dev/null +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -0,0 +1,348 @@ +//go:build unit + +package handler + +import ( + "bytes" + "context" + "encoding/json" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + middleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// 目标:严格验证“antigravity 账号通过 /v1/messages 提供 Claude 服务时”, +// 当账号 credentials.intercept_warmup_requests=true 且请求为 Warmup 时, +// 后端会在转发上游前直接拦截并返回 mock 响应(不依赖上游)。 + +type fakeSchedulerCache struct { + accounts []*service.Account +} + +func (f *fakeSchedulerCache) GetSnapshot(_ context.Context, _ service.SchedulerBucket) ([]*service.Account, bool, error) { + return f.accounts, true, nil +} +func (f *fakeSchedulerCache) SetSnapshot(_ context.Context, _ service.SchedulerBucket, _ []service.Account) error { + return nil +} +func (f *fakeSchedulerCache) GetAccount(_ context.Context, _ int64) (*service.Account, error) { + return nil, nil +} +func (f 
*fakeSchedulerCache) SetAccount(_ context.Context, _ *service.Account) error { return nil } +func (f *fakeSchedulerCache) DeleteAccount(_ context.Context, _ int64) error { return nil } +func (f *fakeSchedulerCache) UpdateLastUsed(_ context.Context, _ map[int64]time.Time) error { + return nil +} +func (f *fakeSchedulerCache) TryLockBucket(_ context.Context, _ service.SchedulerBucket, _ time.Duration) (bool, error) { + return true, nil +} +func (f *fakeSchedulerCache) ListBuckets(_ context.Context) ([]service.SchedulerBucket, error) { + return nil, nil +} +func (f *fakeSchedulerCache) GetOutboxWatermark(_ context.Context) (int64, error) { return 0, nil } +func (f *fakeSchedulerCache) SetOutboxWatermark(_ context.Context, _ int64) error { return nil } + +type fakeGroupRepo struct { + group *service.Group +} + +func (f *fakeGroupRepo) Create(context.Context, *service.Group) error { return nil } +func (f *fakeGroupRepo) GetByID(context.Context, int64) (*service.Group, error) { + return f.group, nil +} +func (f *fakeGroupRepo) GetByIDLite(context.Context, int64) (*service.Group, error) { + return f.group, nil +} +func (f *fakeGroupRepo) Update(context.Context, *service.Group) error { return nil } +func (f *fakeGroupRepo) Delete(context.Context, int64) error { return nil } +func (f *fakeGroupRepo) DeleteCascade(context.Context, int64) ([]int64, error) { return nil, nil } +func (f *fakeGroupRepo) List(context.Context, pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (f *fakeGroupRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]service.Group, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (f *fakeGroupRepo) ListActive(context.Context) ([]service.Group, error) { return nil, nil } +func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service.Group, error) { + return nil, nil +} +func (f *fakeGroupRepo) 
ExistsByName(context.Context, string) (bool, error) { return false, nil } +func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil } +func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { + return 0, nil +} +func (f *fakeGroupRepo) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) { + return nil, nil +} +func (f *fakeGroupRepo) BindAccountsToGroup(context.Context, int64, []int64) error { return nil } +func (f *fakeGroupRepo) UpdateSortOrders(context.Context, []service.GroupSortOrderUpdate) error { + return nil +} + +type fakeConcurrencyCache struct{} + +func (f *fakeConcurrencyCache) AcquireAccountSlot(context.Context, int64, int, string) (bool, error) { + return true, nil +} +func (f *fakeConcurrencyCache) ReleaseAccountSlot(context.Context, int64, string) error { return nil } +func (f *fakeConcurrencyCache) GetAccountConcurrency(context.Context, int64) (int, error) { + return 0, nil +} +func (f *fakeConcurrencyCache) IncrementAccountWaitCount(context.Context, int64, int) (bool, error) { + return true, nil +} +func (f *fakeConcurrencyCache) DecrementAccountWaitCount(context.Context, int64) error { return nil } +func (f *fakeConcurrencyCache) GetAccountWaitingCount(context.Context, int64) (int, error) { + return 0, nil +} +func (f *fakeConcurrencyCache) AcquireUserSlot(context.Context, int64, int, string) (bool, error) { + return true, nil +} +func (f *fakeConcurrencyCache) ReleaseUserSlot(context.Context, int64, string) error { return nil } +func (f *fakeConcurrencyCache) GetUserConcurrency(context.Context, int64) (int, error) { return 0, nil } +func (f *fakeConcurrencyCache) IncrementWaitCount(context.Context, int64, int) (bool, error) { + return true, nil +} +func (f *fakeConcurrencyCache) DecrementWaitCount(context.Context, int64) error { return nil } +func (f *fakeConcurrencyCache) GetAccountsLoadBatch(context.Context, []service.AccountWithConcurrency) 
(map[int64]*service.AccountLoadInfo, error) { + return map[int64]*service.AccountLoadInfo{}, nil +} +func (f *fakeConcurrencyCache) GetUsersLoadBatch(context.Context, []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) { + return map[int64]*service.UserLoadInfo{}, nil +} +func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) { + result := make(map[int64]int, len(accountIDs)) + for _, id := range accountIDs { + result[id] = 0 + } + return result, nil +} +func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil } + +func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) { + t.Helper() + + schedulerCache := &fakeSchedulerCache{accounts: accounts} + schedulerSnapshot := service.NewSchedulerSnapshotService(schedulerCache, nil, nil, nil, nil) + + gwSvc := service.NewGatewayService( + nil, // accountRepo (not used: scheduler snapshot hit) + &fakeGroupRepo{group: group}, + nil, // usageLogRepo + nil, // userRepo + nil, // userSubRepo + nil, // userGroupRateRepo + nil, // cache (disable sticky) + nil, // cfg + schedulerSnapshot, + nil, // concurrencyService (disable load-aware; tryAcquire always acquired) + nil, // billingService + nil, // rateLimitService + nil, // billingCacheService + nil, // identityService + nil, // httpUpstream + nil, // deferredService + nil, // claudeTokenProvider + nil, // sessionLimitCache + nil, // rpmCache + nil, // digestStore + ) + + // RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。 + cfg := &config.Config{RunMode: config.RunModeSimple} + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg) + + concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{}) + concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0) + + h := &GatewayHandler{ + gatewayService: gwSvc, + billingCacheService: billingCacheSvc, 
+ concurrencyHelper: concurrencyHelper, + // 这些字段对本测试不敏感,保持较小即可 + maxAccountSwitches: 1, + maxAccountSwitchesGemini: 1, + } + + cleanup := func() { + billingCacheSvc.Stop() + } + return h, cleanup +} + +func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_MixedSchedulingV1(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(2001) + accountID := int64(1001) + + group := &service.Group{ + ID: groupID, + Hydrated: true, + Platform: service.PlatformAnthropic, // /v1/messages(Claude兼容)入口 + Status: service.StatusActive, + } + + account := &service.Account{ + ID: accountID, + Name: "ag-1", + Platform: service.PlatformAntigravity, + Type: service.AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "tok_xxx", + "intercept_warmup_requests": true, + }, + Extra: map[string]any{ + "mixed_scheduling": true, // 关键:允许被 anthropic 分组混合调度选中 + }, + Concurrency: 1, + Priority: 1, + Status: service.StatusActive, + Schedulable: true, + AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}}, + } + + h, cleanup := newTestGatewayHandler(t, group, []*service.Account{account}) + defer cleanup() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + body := []byte(`{ + "model": "claude-sonnet-4-5", + "max_tokens": 256, + "messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}] + }`) + req := httptest.NewRequest("POST", "/v1/messages", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req = req.WithContext(context.WithValue(req.Context(), ctxkey.Group, group)) + c.Request = req + + apiKey := &service.APIKey{ + ID: 3001, + UserID: 4001, + GroupID: &groupID, + Status: service.StatusActive, + User: &service.User{ + ID: 4001, + Concurrency: 10, + Balance: 100, + }, + Group: group, + } + + c.Set(string(middleware.ContextKeyAPIKey), apiKey) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: 10}) + + h.Messages(c) + + 
require.Equal(t, 200, rec.Code) + + // 断言:确实选中了 antigravity 账号(不是纯函数测试,而是从 Handler 里验证调度结果) + selected, ok := c.Get(opsAccountIDKey) + require.True(t, ok) + require.Equal(t, accountID, selected) + + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, "msg_mock_warmup", resp["id"]) + require.Equal(t, "claude-sonnet-4-5", resp["model"]) + + content, ok := resp["content"].([]any) + require.True(t, ok) + require.Len(t, content, 1) + first, ok := content[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "New Conversation", first["text"]) +} + +func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_ForcePlatform(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(2002) + accountID := int64(1002) + + group := &service.Group{ + ID: groupID, + Hydrated: true, + Platform: service.PlatformAntigravity, + Status: service.StatusActive, + } + + account := &service.Account{ + ID: accountID, + Name: "ag-2", + Platform: service.PlatformAntigravity, + Type: service.AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "tok_xxx", + "intercept_warmup_requests": true, + }, + Concurrency: 1, + Priority: 1, + Status: service.StatusActive, + Schedulable: true, + AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}}, + } + + h, cleanup := newTestGatewayHandler(t, group, []*service.Account{account}) + defer cleanup() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + body := []byte(`{ + "model": "claude-sonnet-4-5", + "max_tokens": 256, + "messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}] + }`) + req := httptest.NewRequest("POST", "/antigravity/v1/messages", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + // 模拟 routes/gateway.go 里的 ForcePlatform 中间件效果: + // - 写入 request.Context(Service读取) + // - 写入 gin.Context(Handler快速读取) + ctx := context.WithValue(req.Context(), ctxkey.Group, group) + 
ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformAntigravity) + req = req.WithContext(ctx) + c.Request = req + c.Set(string(middleware.ContextKeyForcePlatform), service.PlatformAntigravity) + + apiKey := &service.APIKey{ + ID: 3002, + UserID: 4002, + GroupID: &groupID, + Status: service.StatusActive, + User: &service.User{ + ID: 4002, + Concurrency: 10, + Balance: 100, + }, + Group: group, + } + + c.Set(string(middleware.ContextKeyAPIKey), apiKey) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: 10}) + + h.Messages(c) + + require.Equal(t, 200, rec.Code) + + selected, ok := c.Get(opsAccountIDKey) + require.True(t, ok) + require.Equal(t, accountID, selected) + + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, "msg_mock_warmup", resp["id"]) + require.Equal(t, "claude-sonnet-4-5", resp["model"]) +} diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index 0393f954..09e6c09b 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -4,8 +4,9 @@ import ( "context" "encoding/json" "fmt" - "math/rand" + "math/rand/v2" "net/http" + "strings" "sync" "time" @@ -17,23 +18,91 @@ import ( // claudeCodeValidator is a singleton validator for Claude Code client detection var claudeCodeValidator = service.NewClaudeCodeValidator() +const claudeCodeParsedRequestContextKey = "claude_code_parsed_request" + // SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中 // 返回更新后的 context -func SetClaudeCodeClientContext(c *gin.Context, body []byte) { - // 解析请求体为 map - var bodyMap map[string]any - if len(body) > 0 { - _ = json.Unmarshal(body, &bodyMap) +func SetClaudeCodeClientContext(c *gin.Context, body []byte, parsedReq *service.ParsedRequest) { + if c == nil || c.Request == nil { + return + } + if parsedReq != nil { + 
c.Set(claudeCodeParsedRequestContextKey, parsedReq) } - // 验证是否为 Claude Code 客户端 - isClaudeCode := claudeCodeValidator.Validate(c.Request, bodyMap) + ua := c.GetHeader("User-Agent") + // Fast path:非 Claude CLI UA 直接判定 false,避免热路径二次 JSON 反序列化。 + if !claudeCodeValidator.ValidateUserAgent(ua) { + ctx := service.SetClaudeCodeClient(c.Request.Context(), false) + c.Request = c.Request.WithContext(ctx) + return + } + + isClaudeCode := false + if !strings.Contains(c.Request.URL.Path, "messages") { + // 与 Validate 行为一致:非 messages 路径 UA 命中即可视为 Claude Code 客户端。 + isClaudeCode = true + } else { + // 仅在确认为 Claude CLI 且 messages 路径时再做 body 解析。 + bodyMap := claudeCodeBodyMapFromParsedRequest(parsedReq) + if bodyMap == nil { + bodyMap = claudeCodeBodyMapFromContextCache(c) + } + if bodyMap == nil && len(body) > 0 { + _ = json.Unmarshal(body, &bodyMap) + } + isClaudeCode = claudeCodeValidator.Validate(c.Request, bodyMap) + } // 更新 request context ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode) + + // 仅在确认为 Claude Code 客户端时提取版本号写入 context + if isClaudeCode { + if version := claudeCodeValidator.ExtractVersion(ua); version != "" { + ctx = service.SetClaudeCodeVersion(ctx, version) + } + } + c.Request = c.Request.WithContext(ctx) } +func claudeCodeBodyMapFromParsedRequest(parsedReq *service.ParsedRequest) map[string]any { + if parsedReq == nil { + return nil + } + bodyMap := map[string]any{ + "model": parsedReq.Model, + } + if parsedReq.System != nil || parsedReq.HasSystem { + bodyMap["system"] = parsedReq.System + } + if parsedReq.MetadataUserID != "" { + bodyMap["metadata"] = map[string]any{"user_id": parsedReq.MetadataUserID} + } + return bodyMap +} + +func claudeCodeBodyMapFromContextCache(c *gin.Context) map[string]any { + if c == nil { + return nil + } + if cached, ok := c.Get(service.OpenAIParsedRequestBodyKey); ok { + if bodyMap, ok := cached.(map[string]any); ok { + return bodyMap + } + } + if cached, ok := c.Get(claudeCodeParsedRequestContextKey); ok { + 
switch v := cached.(type) { + case *service.ParsedRequest: + return claudeCodeBodyMapFromParsedRequest(v) + case service.ParsedRequest: + return claudeCodeBodyMapFromParsedRequest(&v) + } + } + return nil +} + // 并发槽位等待相关常量 // // 性能优化说明: @@ -104,31 +173,24 @@ func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFo // wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation. // 用于避免客户端断开或上游超时导致的并发槽位泄漏。 -// 修复:添加 quit channel 确保 goroutine 及时退出,避免泄露 +// 优化:基于 context.AfterFunc 注册回调,避免每请求额外守护 goroutine。 func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() { if releaseFunc == nil { return nil } var once sync.Once - quit := make(chan struct{}) + var stop func() bool release := func() { once.Do(func() { + if stop != nil { + _ = stop() + } releaseFunc() - close(quit) // 通知监听 goroutine 退出 }) } - go func() { - select { - case <-ctx.Done(): - // Context 取消时释放资源 - release() - case <-quit: - // 正常释放已完成,goroutine 退出 - return - } - }() + stop = context.AfterFunc(ctx, release) return release } @@ -153,6 +215,32 @@ func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accou h.concurrencyService.DecrementAccountWaitCount(ctx, accountID) } +// TryAcquireUserSlot 尝试立即获取用户并发槽位。 +// 返回值: (releaseFunc, acquired, error) +func (h *ConcurrencyHelper) TryAcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (func(), bool, error) { + result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency) + if err != nil { + return nil, false, err + } + if !result.Acquired { + return nil, false, nil + } + return result.ReleaseFunc, true, nil +} + +// TryAcquireAccountSlot 尝试立即获取账号并发槽位。 +// 返回值: (releaseFunc, acquired, error) +func (h *ConcurrencyHelper) TryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (func(), bool, error) { + result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) + if err != nil { + 
return nil, false, err + } + if !result.Acquired { + return nil, false, nil + } + return result.ReleaseFunc, true, nil +} + // AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary. // For streaming requests, sends ping events during the wait. // streamStarted is updated if streaming response has begun. @@ -160,13 +248,13 @@ func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64 ctx := c.Request.Context() // Try to acquire immediately - result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency) + releaseFunc, acquired, err := h.TryAcquireUserSlot(ctx, userID, maxConcurrency) if err != nil { return nil, err } - if result.Acquired { - return result.ReleaseFunc, nil + if acquired { + return releaseFunc, nil } // Need to wait - handle streaming ping if needed @@ -180,13 +268,13 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID ctx := c.Request.Context() // Try to acquire immediately - result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) + releaseFunc, acquired, err := h.TryAcquireAccountSlot(ctx, accountID, maxConcurrency) if err != nil { return nil, err } - if result.Acquired { - return result.ReleaseFunc, nil + if acquired { + return releaseFunc, nil } // Need to wait - handle streaming ping if needed @@ -196,27 +284,29 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID // waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests. // streamStarted pointer is updated when streaming begins (for proper error handling by caller). 
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) { - return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted) + return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted, false) } // waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout. -func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) { +func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool, tryImmediate bool) (func(), error) { ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) defer cancel() - // Try immediate acquire first (avoid unnecessary wait) - var result *service.AcquireResult - var err error - if slotType == "user" { - result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency) - } else { - result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency) + acquireSlot := func() (*service.AcquireResult, error) { + if slotType == "user" { + return h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency) + } + return h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency) } - if err != nil { - return nil, err - } - if result.Acquired { - return result.ReleaseFunc, nil + + if tryImmediate { + result, err := acquireSlot() + if err != nil { + return nil, err + } + if result.Acquired { + return result.ReleaseFunc, nil + } } // Determine if ping is needed (streaming + ping format defined) @@ -242,7 +332,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType backoff := initialBackoff timer := time.NewTimer(backoff) defer 
timer.Stop() - rng := rand.New(rand.NewSource(time.Now().UnixNano())) for { select { @@ -268,15 +357,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType case <-timer.C: // Try to acquire slot - var result *service.AcquireResult - var err error - - if slotType == "user" { - result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency) - } else { - result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency) - } - + result, err := acquireSlot() if err != nil { return nil, err } @@ -284,7 +365,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType if result.Acquired { return result.ReleaseFunc, nil } - backoff = nextBackoff(backoff, rng) + backoff = nextBackoff(backoff) timer.Reset(backoff) } } @@ -292,26 +373,22 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType // AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping). func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) { - return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted) + return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted, true) } // nextBackoff 计算下一次退避时间 // 性能优化:使用指数退避 + 随机抖动,避免惊群效应 // current: 当前退避时间 -// rng: 随机数生成器(可为 nil,此时不添加抖动) // 返回值:下一次退避时间(100ms ~ 2s 之间) -func nextBackoff(current time.Duration, rng *rand.Rand) time.Duration { +func nextBackoff(current time.Duration) time.Duration { // 指数退避:当前时间 * 1.5 next := time.Duration(float64(current) * backoffMultiplier) if next > maxBackoff { next = maxBackoff } - if rng == nil { - return next - } // 添加 ±20% 的随机抖动(jitter 范围 0.8 ~ 1.2) // 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis - jitter := 0.8 + rng.Float64()*0.4 + jitter := 0.8 + rand.Float64()*0.4 
jittered := time.Duration(float64(next) * jitter) if jittered < initialBackoff { return initialBackoff diff --git a/backend/internal/handler/gateway_helper_backoff_test.go b/backend/internal/handler/gateway_helper_backoff_test.go new file mode 100644 index 00000000..a5056bbb --- /dev/null +++ b/backend/internal/handler/gateway_helper_backoff_test.go @@ -0,0 +1,106 @@ +package handler + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- Task 6.2 验证: math/rand/v2 迁移后 nextBackoff 行为正确 --- + +func TestNextBackoff_ExponentialGrowth(t *testing.T) { + // 验证退避时间指数增长(乘数 1.5) + // 由于有随机抖动(±20%),需要验证范围 + current := initialBackoff // 100ms + + for i := 0; i < 10; i++ { + next := nextBackoff(current) + + // 退避结果应在 [initialBackoff, maxBackoff] 范围内 + assert.GreaterOrEqual(t, int64(next), int64(initialBackoff), + "第 %d 次退避不应低于初始值 %v", i, initialBackoff) + assert.LessOrEqual(t, int64(next), int64(maxBackoff), + "第 %d 次退避不应超过最大值 %v", i, maxBackoff) + + // 为下一轮提供当前退避值 + current = next + } +} + +func TestNextBackoff_BoundedByMaxBackoff(t *testing.T) { + // 即使输入非常大,输出也不超过 maxBackoff + for i := 0; i < 100; i++ { + result := nextBackoff(10 * time.Second) + assert.LessOrEqual(t, int64(result), int64(maxBackoff), + "退避值不应超过 maxBackoff") + } +} + +func TestNextBackoff_BoundedByInitialBackoff(t *testing.T) { + // 即使输入非常小,输出也不低于 initialBackoff + for i := 0; i < 100; i++ { + result := nextBackoff(1 * time.Millisecond) + assert.GreaterOrEqual(t, int64(result), int64(initialBackoff), + "退避值不应低于 initialBackoff") + } +} + +func TestNextBackoff_HasJitter(t *testing.T) { + // 验证多次调用会产生不同的值(随机抖动生效) + // 使用相同的输入调用 50 次,收集结果 + results := make(map[time.Duration]bool) + current := 500 * time.Millisecond + + for i := 0; i < 50; i++ { + result := nextBackoff(current) + results[result] = true + } + + // 50 次调用应该至少有 2 个不同的值(抖动存在) + require.Greater(t, len(results), 1, + "nextBackoff 应产生随机抖动,但所有 50 次调用结果相同") +} + +func 
TestNextBackoff_InitialValueGrows(t *testing.T) { + // 验证从初始值开始,退避趋势是增长的 + current := initialBackoff + var sum time.Duration + + runs := 100 + for i := 0; i < runs; i++ { + next := nextBackoff(current) + sum += next + current = next + } + + avg := sum / time.Duration(runs) + // 平均退避时间应大于初始值(因为指数增长 + 上限) + assert.Greater(t, int64(avg), int64(initialBackoff), + "平均退避时间应大于初始退避值") +} + +func TestNextBackoff_ConvergesToMaxBackoff(t *testing.T) { + // 从初始值开始,经过多次退避后应收敛到 maxBackoff 附近 + current := initialBackoff + for i := 0; i < 20; i++ { + current = nextBackoff(current) + } + + // 经过 20 次迭代后,应该已经到达 maxBackoff 区间 + // 由于抖动,允许 ±20% 的范围 + lowerBound := time.Duration(float64(maxBackoff) * 0.8) + assert.GreaterOrEqual(t, int64(current), int64(lowerBound), + "经过多次退避后应收敛到 maxBackoff 附近") +} + +func BenchmarkNextBackoff(b *testing.B) { + current := initialBackoff + for i := 0; i < b.N; i++ { + current = nextBackoff(current) + if current > maxBackoff { + current = initialBackoff + } + } +} diff --git a/backend/internal/handler/gateway_helper_fastpath_test.go b/backend/internal/handler/gateway_helper_fastpath_test.go new file mode 100644 index 00000000..31d489f0 --- /dev/null +++ b/backend/internal/handler/gateway_helper_fastpath_test.go @@ -0,0 +1,122 @@ +package handler + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +type concurrencyCacheMock struct { + acquireUserSlotFn func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) + acquireAccountSlotFn func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) + releaseUserCalled int32 + releaseAccountCalled int32 +} + +func (m *concurrencyCacheMock) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + if m.acquireAccountSlotFn != nil { + return m.acquireAccountSlotFn(ctx, accountID, 
maxConcurrency, requestID) + } + return false, nil +} + +func (m *concurrencyCacheMock) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error { + atomic.AddInt32(&m.releaseAccountCalled, 1) + return nil +} + +func (m *concurrencyCacheMock) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) { + return 0, nil +} + +func (m *concurrencyCacheMock) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { + result := make(map[int64]int, len(accountIDs)) + for _, accountID := range accountIDs { + result[accountID] = 0 + } + return result, nil +} + +func (m *concurrencyCacheMock) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { + return true, nil +} + +func (m *concurrencyCacheMock) DecrementAccountWaitCount(ctx context.Context, accountID int64) error { + return nil +} + +func (m *concurrencyCacheMock) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { + return 0, nil +} + +func (m *concurrencyCacheMock) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + if m.acquireUserSlotFn != nil { + return m.acquireUserSlotFn(ctx, userID, maxConcurrency, requestID) + } + return false, nil +} + +func (m *concurrencyCacheMock) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error { + atomic.AddInt32(&m.releaseUserCalled, 1) + return nil +} + +func (m *concurrencyCacheMock) GetUserConcurrency(ctx context.Context, userID int64) (int, error) { + return 0, nil +} + +func (m *concurrencyCacheMock) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { + return true, nil +} + +func (m *concurrencyCacheMock) DecrementWaitCount(ctx context.Context, userID int64) error { + return nil +} + +func (m *concurrencyCacheMock) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, 
error) { + return map[int64]*service.AccountLoadInfo{}, nil +} + +func (m *concurrencyCacheMock) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) { + return map[int64]*service.UserLoadInfo{}, nil +} + +func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { + return nil +} + +func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) { + cache := &concurrencyCacheMock{ + acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil + }, + } + helper := NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second) + + release, acquired, err := helper.TryAcquireUserSlot(context.Background(), 101, 2) + require.NoError(t, err) + require.True(t, acquired) + require.NotNil(t, release) + + release() + require.Equal(t, int32(1), atomic.LoadInt32(&cache.releaseUserCalled)) +} + +func TestConcurrencyHelper_TryAcquireAccountSlot_NotAcquired(t *testing.T) { + cache := &concurrencyCacheMock{ + acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + return false, nil + }, + } + helper := NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second) + + release, acquired, err := helper.TryAcquireAccountSlot(context.Background(), 201, 1) + require.NoError(t, err) + require.False(t, acquired) + require.Nil(t, release) + require.Equal(t, int32(0), atomic.LoadInt32(&cache.releaseAccountCalled)) +} diff --git a/backend/internal/handler/gateway_helper_hotpath_test.go b/backend/internal/handler/gateway_helper_hotpath_test.go new file mode 100644 index 00000000..f8f7eaca --- /dev/null +++ b/backend/internal/handler/gateway_helper_hotpath_test.go @@ -0,0 +1,317 @@ +package handler + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + 
"github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type helperConcurrencyCacheStub struct { + mu sync.Mutex + + accountSeq []bool + userSeq []bool + + accountAcquireCalls int + userAcquireCalls int + accountReleaseCalls int + userReleaseCalls int +} + +func (s *helperConcurrencyCacheStub) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.accountAcquireCalls++ + if len(s.accountSeq) == 0 { + return false, nil + } + v := s.accountSeq[0] + s.accountSeq = s.accountSeq[1:] + return v, nil +} + +func (s *helperConcurrencyCacheStub) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.accountReleaseCalls++ + return nil +} + +func (s *helperConcurrencyCacheStub) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) { + return 0, nil +} + +func (s *helperConcurrencyCacheStub) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { + out := make(map[int64]int, len(accountIDs)) + for _, accountID := range accountIDs { + out[accountID] = 0 + } + return out, nil +} + +func (s *helperConcurrencyCacheStub) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { + return true, nil +} + +func (s *helperConcurrencyCacheStub) DecrementAccountWaitCount(ctx context.Context, accountID int64) error { + return nil +} + +func (s *helperConcurrencyCacheStub) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { + return 0, nil +} + +func (s *helperConcurrencyCacheStub) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.userAcquireCalls++ + if len(s.userSeq) == 0 { + return false, nil + } + v := s.userSeq[0] + s.userSeq = s.userSeq[1:] + return v, 
nil +} + +func (s *helperConcurrencyCacheStub) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.userReleaseCalls++ + return nil +} + +func (s *helperConcurrencyCacheStub) GetUserConcurrency(ctx context.Context, userID int64) (int, error) { + return 0, nil +} + +func (s *helperConcurrencyCacheStub) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { + return true, nil +} + +func (s *helperConcurrencyCacheStub) DecrementWaitCount(ctx context.Context, userID int64) error { + return nil +} + +func (s *helperConcurrencyCacheStub) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) { + out := make(map[int64]*service.AccountLoadInfo, len(accounts)) + for _, acc := range accounts { + out[acc.ID] = &service.AccountLoadInfo{AccountID: acc.ID} + } + return out, nil +} + +func (s *helperConcurrencyCacheStub) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) { + out := make(map[int64]*service.UserLoadInfo, len(users)) + for _, user := range users { + out[user.ID] = &service.UserLoadInfo{UserID: user.ID} + } + return out, nil +} + +func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { + return nil +} + +func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(method, path, nil) + return c, rec +} + +func validClaudeCodeBodyJSON() []byte { + return []byte(`{ + "model":"claude-3-5-sonnet-20241022", + "system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}], + "metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"} + }`) +} + +func 
TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) { + t.Run("non_cli_user_agent_sets_false", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + c.Request.Header.Set("User-Agent", "curl/8.6.0") + + SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON(), nil) + require.False(t, service.IsClaudeCodeClient(c.Request.Context())) + }) + + t.Run("cli_non_messages_path_sets_true", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodGet, "/v1/models") + c.Request.Header.Set("User-Agent", "claude-cli/1.0.1") + + SetClaudeCodeClientContext(c, nil, nil) + require.True(t, service.IsClaudeCodeClient(c.Request.Context())) + }) + + t.Run("cli_messages_path_valid_body_sets_true", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + c.Request.Header.Set("User-Agent", "claude-cli/1.0.1") + c.Request.Header.Set("X-App", "claude-code") + c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24") + c.Request.Header.Set("anthropic-version", "2023-06-01") + + SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON(), nil) + require.True(t, service.IsClaudeCodeClient(c.Request.Context())) + }) + + t.Run("cli_messages_path_invalid_body_sets_false", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + c.Request.Header.Set("User-Agent", "claude-cli/1.0.1") + // 缺少严格校验所需 header + body 字段 + SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`), nil) + require.False(t, service.IsClaudeCodeClient(c.Request.Context())) + }) +} + +func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing.T) { + t.Run("reuse parsed request without body unmarshal", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + c.Request.Header.Set("User-Agent", "claude-cli/1.0.1") + c.Request.Header.Set("X-App", "claude-code") + c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24") + 
c.Request.Header.Set("anthropic-version", "2023-06-01") + + parsedReq := &service.ParsedRequest{ + Model: "claude-3-5-sonnet-20241022", + System: []any{ + map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."}, + }, + MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123", + } + + // body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。 + SetClaudeCodeClientContext(c, []byte(`{invalid`), parsedReq) + require.True(t, service.IsClaudeCodeClient(c.Request.Context())) + }) + + t.Run("reuse context cache without body unmarshal", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + c.Request.Header.Set("User-Agent", "claude-cli/1.0.1") + c.Request.Header.Set("X-App", "claude-code") + c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24") + c.Request.Header.Set("anthropic-version", "2023-06-01") + c.Set(service.OpenAIParsedRequestBodyKey, map[string]any{ + "model": "claude-3-5-sonnet-20241022", + "system": []any{ + map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."}, + }, + "metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}, + }) + + SetClaudeCodeClientContext(c, []byte(`{invalid`), nil) + require.True(t, service.IsClaudeCodeClient(c.Request.Context())) + }) +} + +func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) { + cache := &helperConcurrencyCacheStub{ + accountSeq: []bool{false, true}, + userSeq: []bool{false, true}, + } + concurrency := service.NewConcurrencyService(cache) + helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond) + + t.Run("account_slot_acquired_after_retry", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + streamStarted := false + release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, time.Second, 
false, &streamStarted, true) + require.NoError(t, err) + require.NotNil(t, release) + require.False(t, streamStarted) + release() + require.GreaterOrEqual(t, cache.accountAcquireCalls, 2) + require.GreaterOrEqual(t, cache.accountReleaseCalls, 1) + }) + + t.Run("user_slot_acquired_after_retry", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + streamStarted := false + release, err := helper.waitForSlotWithPingTimeout(c, "user", 202, 3, time.Second, false, &streamStarted, true) + require.NoError(t, err) + require.NotNil(t, release) + release() + require.GreaterOrEqual(t, cache.userAcquireCalls, 2) + require.GreaterOrEqual(t, cache.userReleaseCalls, 1) + }) +} + +func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) { + cache := &helperConcurrencyCacheStub{ + accountSeq: []bool{false, false, false}, + } + concurrency := service.NewConcurrencyService(cache) + + t.Run("timeout_returns_concurrency_error", func(t *testing.T) { + helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond) + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + streamStarted := false + release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 130*time.Millisecond, false, &streamStarted, true) + require.Nil(t, release) + var cErr *ConcurrencyError + require.ErrorAs(t, err, &cErr) + require.True(t, cErr.IsTimeout) + }) + + t.Run("stream_mode_sends_ping_before_timeout", func(t *testing.T) { + helper := NewConcurrencyHelper(concurrency, SSEPingFormatComment, 10*time.Millisecond) + c, rec := newHelperTestContext(http.MethodPost, "/v1/messages") + streamStarted := false + release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 70*time.Millisecond, true, &streamStarted, true) + require.Nil(t, release) + var cErr *ConcurrencyError + require.ErrorAs(t, err, &cErr) + require.True(t, cErr.IsTimeout) + require.True(t, streamStarted) + require.Contains(t, rec.Body.String(), ":\n\n") + }) 
+} + +func TestWaitForSlotWithPingTimeout_AcquireError(t *testing.T) { + errCache := &helperConcurrencyCacheStubWithError{ + err: errors.New("redis unavailable"), + } + concurrency := service.NewConcurrencyService(errCache) + helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond) + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + streamStarted := false + release, err := helper.waitForSlotWithPingTimeout(c, "account", 1, 1, 200*time.Millisecond, false, &streamStarted, true) + require.Nil(t, release) + require.Error(t, err) + require.Contains(t, err.Error(), "redis unavailable") +} + +func TestAcquireAccountSlotWithWaitTimeout_ImmediateAttemptBeforeBackoff(t *testing.T) { + cache := &helperConcurrencyCacheStub{ + accountSeq: []bool{false}, + } + concurrency := service.NewConcurrencyService(cache) + helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond) + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + streamStarted := false + + release, err := helper.AcquireAccountSlotWithWaitTimeout(c, 301, 1, 30*time.Millisecond, false, &streamStarted) + require.Nil(t, release) + var cErr *ConcurrencyError + require.ErrorAs(t, err, &cErr) + require.True(t, cErr.IsTimeout) + require.GreaterOrEqual(t, cache.accountAcquireCalls, 1) +} + +type helperConcurrencyCacheStubWithError struct { + helperConcurrencyCacheStub + err error +} + +func (s *helperConcurrencyCacheStubWithError) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + return false, s.err +} diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 3d25505b..50af9c8f 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -7,24 +7,23 @@ import ( "encoding/hex" "encoding/json" "errors" - "io" - "log" "net/http" "regexp" "strings" - "time" 
"github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" - "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/gemini" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/google/uuid" "github.com/gin-gonic/gin" + "go.uber.org/zap" ) // geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值 @@ -143,6 +142,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { googleError(c, http.StatusInternalServerError, "User context not found") return } + reqLog := requestLogger( + c, + "handler.gemini_v1beta.models", + zap.Int64("user_id", authSubject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) // 检查平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则要求 gemini 分组 if !middleware.HasForcePlatform(c) { @@ -159,8 +165,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } stream := action == "streamGenerateContent" + reqLog = reqLog.With(zap.String("model", modelName), zap.String("action", action), zap.Bool("stream", stream)) - body, err := io.ReadAll(c.Request.Body) + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) if err != nil { if maxErr, ok := extractMaxBytesError(err); ok { googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit)) @@ -187,8 +194,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait) waitCounted := false if err != nil { - log.Printf("Increment wait count failed: %v", err) + reqLog.Warn("gemini.user_wait_counter_increment_failed", zap.Error(err)) } else if !canWait { + 
reqLog.Info("gemini.user_wait_queue_full", zap.Int("max_wait", maxWait)) googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later") return } @@ -208,6 +216,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted) if err != nil { + reqLog.Warn("gemini.user_slot_acquire_failed", zap.Error(err)) googleError(c, http.StatusTooManyRequests, err.Error()) return } @@ -223,6 +232,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 2) billing eligibility check (after wait) if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err)) status, _, message := billingErrorDetails(err) googleError(c, status, message) return @@ -252,6 +262,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { var sessionBoundAccountID int64 if sessionKey != "" { sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey) + if sessionBoundAccountID > 0 { + prefetchedGroupID := int64(0) + if apiKey.GroupID != nil { + prefetchedGroupID = *apiKey.GroupID + } + ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled()) + c.Request = c.Request.WithContext(ctx) + } } // === Gemini 内容摘要会话 Fallback 逻辑 === @@ -296,8 +314,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { matchedDigestChain = foundMatchedChain sessionBoundAccountID = foundAccountID geminiSessionUUID = foundUUID - log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s", - safeShortPrefix(foundUUID, 8), foundAccountID, truncateDigestChain(geminiDigestChain)) + reqLog.Info("gemini.digest_fallback_matched", + 
zap.String("session_uuid_prefix", safeShortPrefix(foundUUID, 8)), + zap.Int64("account_id", foundAccountID), + zap.String("digest_chain", truncateDigestChain(geminiDigestChain)), + ) // 关键:如果原 sessionKey 为空,使用 prefixHash + uuid 作为 sessionKey // 这样 SelectAccountWithLoadAwareness 的粘性会话逻辑会优先使用匹配到的账号 @@ -321,55 +342,54 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0 cleanedForUnknownBinding := false - maxAccountSwitches := h.maxAccountSwitchesGemini - switchCount := 0 - failedAccountIDs := make(map[int64]struct{}) - var lastFailoverErr *service.UpstreamFailoverError - var forceCacheBilling bool // 粘性会话切换时的缓存计费标记 + fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession) // 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。 // 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。 if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) { - ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) c.Request = c.Request.WithContext(ctx) } for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制 + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "") // Gemini 不使用会话限制 if err != nil { - if len(failedAccountIDs) == 0 { + if len(fs.FailedAccountIDs) == 0 { googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) return } - // Antigravity 单账号退避重试:分组内没有其他可用账号时, - // 对 503 错误不直接返回,而是清除排除列表、等待退避后重试同一个账号。 - // 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。 - if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= 
maxAccountSwitches { - if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) { - log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches) - failedAccountIDs = make(map[int64]struct{}) - // 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换 - ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) - c.Request = c.Request.WithContext(ctx) - continue - } + action := fs.HandleSelectionExhausted(c.Request.Context()) + switch action { + case FailoverContinue: + ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) + c.Request = c.Request.WithContext(ctx) + continue + case FailoverCanceled: + return + default: // FailoverExhausted + h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr) + return } - h.handleGeminiFailoverExhausted(c, lastFailoverErr) - return } account := selection.Account - setOpsSelectedAccount(c, account.ID) + setOpsSelectedAccount(c, account.ID, account.Platform) // 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature // 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。 if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID { - log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID) + reqLog.Info("gemini.sticky_session_account_switched", + zap.Int64("from_account_id", sessionBoundAccountID), + zap.Int64("to_account_id", account.ID), + zap.Bool("clean_thought_signature", true), + ) body = service.CleanGeminiNativeThoughtSignatures(body) sessionBoundAccountID = account.ID } else if sessionKey != "" && sessionBoundAccountID == 0 && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) { // 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,客户端继续携带旧签名。 // 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。 - log.Printf("[Gemini] Sticky session binding missing, cleaning thoughtSignature proactively") + 
reqLog.Info("gemini.sticky_session_binding_missing", + zap.Bool("clean_thought_signature", true), + ) body = service.CleanGeminiNativeThoughtSignatures(body) cleanedForUnknownBinding = true sessionBoundAccountID = account.ID @@ -388,9 +408,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { accountWaitCounted := false canWait, err := geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) if err != nil { - log.Printf("Increment account wait count failed: %v", err) + reqLog.Warn("gemini.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) + reqLog.Info("gemini.account_wait_queue_full", + zap.Int64("account_id", account.ID), + zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), + ) googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later") return } @@ -412,6 +435,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { &streamStarted, ) if err != nil { + reqLog.Warn("gemini.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) googleError(c, http.StatusTooManyRequests, err.Error()) return } @@ -420,7 +444,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { accountWaitCounted = false } if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { - log.Printf("Bind sticky session failed: %v", err) + reqLog.Warn("gemini.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } } // 账号槽位/等待计数需要在超时或断开时安全回收 @@ -429,8 +453,8 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 5) forward (根据平台分流) var result *service.ForwardResult requestCtx := c.Request.Context() - if switchCount > 0 { - requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) + if fs.SwitchCount > 0 { + 
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled()) } if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession) @@ -443,27 +467,19 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { - failedAccountIDs[account.ID] = struct{}{} - if needForceCacheBilling(hasBoundSession, failoverErr) { - forceCacheBilling = true - } - if switchCount >= maxAccountSwitches { - lastFailoverErr = failoverErr - h.handleGeminiFailoverExhausted(c, lastFailoverErr) + failoverAction := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr) + switch failoverAction { + case FailoverContinue: + continue + case FailoverExhausted: + h.handleGeminiFailoverExhausted(c, fs.LastFailoverErr) + return + case FailoverCanceled: return } - lastFailoverErr = failoverErr - switchCount++ - log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) - if account.Platform == service.PlatformAntigravity { - if !sleepFailoverDelay(c.Request.Context(), switchCount) { - return - } - } - continue } // ForwardNative already wrote the response - log.Printf("Gemini native forward failed: %v", err) + reqLog.Error("gemini.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err)) return } @@ -482,31 +498,39 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { account.ID, matchedDigestChain, ); err != nil { - log.Printf("[Gemini] Failed to save digest session: %v", err) + reqLog.Warn("gemini.digest_session_save_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } } - // 6) record usage async (Gemini 使用长上下文双倍计费) - go func(result 
*service.ForwardResult, usedAccount *service.Account, ua, ip string, fcb bool) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - + // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 + h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{ Result: result, APIKey: apiKey, User: apiKey.User, - Account: usedAccount, + Account: account, Subscription: subscription, - UserAgent: ua, - IPAddress: ip, + UserAgent: userAgent, + IPAddress: clientIP, LongContextThreshold: 200000, // Gemini 200K 阈值 LongContextMultiplier: 2.0, // 超出部分双倍计费 - ForceCacheBilling: fcb, + ForceCacheBilling: fs.ForceCacheBilling, APIKeyService: h.apiKeyService, }); err != nil { - log.Printf("Record usage failed: %v", err) + logger.L().With( + zap.String("component", "handler.gemini_v1beta.models"), + zap.Int64("user_id", authSubject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.String("model", modelName), + zap.Int64("account_id", account.ID), + ).Error("gemini.record_usage_failed", zap.Error(err)) } - }(result, account, userAgent, clientIP, forceCacheBilling) + }) + reqLog.Debug("gemini.request_completed", + zap.Int64("account_id", account.ID), + zap.Int("switch_count", fs.SwitchCount), + ) return } } diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index b2b12c0d..1e1247fc 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -11,6 +11,7 @@ type AdminHandlers struct { Group *admin.GroupHandler Account *admin.AccountHandler Announcement *admin.AnnouncementHandler + DataManagement *admin.DataManagementHandler OAuth *admin.OAuthHandler OpenAIOAuth *admin.OpenAIOAuthHandler GeminiOAuth *admin.GeminiOAuthHandler @@ -25,6 +26,7 @@ type AdminHandlers struct { Usage *admin.UsageHandler UserAttribute *admin.UserAttributeHandler ErrorPassthrough 
*admin.ErrorPassthroughHandler + APIKey *admin.AdminAPIKeyHandler } // Handlers contains all HTTP handlers @@ -39,6 +41,8 @@ type Handlers struct { Admin *AdminHandlers Gateway *GatewayHandler OpenAIGateway *OpenAIGatewayHandler + SoraGateway *SoraGatewayHandler + SoraClient *SoraClientHandler Setting *SettingHandler Totp *TotpHandler } diff --git a/backend/internal/handler/idempotency_helper.go b/backend/internal/handler/idempotency_helper.go new file mode 100644 index 00000000..bca63b6b --- /dev/null +++ b/backend/internal/handler/idempotency_helper.go @@ -0,0 +1,65 @@ +package handler + +import ( + "context" + "strconv" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +func executeUserIdempotentJSON( + c *gin.Context, + scope string, + payload any, + ttl time.Duration, + execute func(context.Context) (any, error), +) { + coordinator := service.DefaultIdempotencyCoordinator() + if coordinator == nil { + data, err := execute(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, data) + return + } + + actorScope := "user:0" + if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok { + actorScope = "user:" + strconv.FormatInt(subject.UserID, 10) + } + + result, err := coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{ + Scope: scope, + ActorScope: actorScope, + Method: c.Request.Method, + Route: c.FullPath(), + IdempotencyKey: c.GetHeader("Idempotency-Key"), + Payload: payload, + RequireKey: true, + TTL: ttl, + }, execute) + if err != nil { + if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) { + service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, 
"handler_fail_close") + logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=fail_close", c.Request.Method, c.FullPath(), scope) + } + if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } + response.ErrorFrom(c, err) + return + } + if result != nil && result.Replayed { + c.Header("X-Idempotency-Replayed", "true") + } + response.Success(c, result.Data) +} diff --git a/backend/internal/handler/idempotency_helper_test.go b/backend/internal/handler/idempotency_helper_test.go new file mode 100644 index 00000000..e8213a2b --- /dev/null +++ b/backend/internal/handler/idempotency_helper_test.go @@ -0,0 +1,285 @@ +package handler + +import ( + "bytes" + "context" + "errors" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type userStoreUnavailableRepoStub struct{} + +func (userStoreUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) { + return false, errors.New("store unavailable") +} +func (userStoreUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) { + return nil, errors.New("store unavailable") +} +func (userStoreUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + return false, errors.New("store unavailable") +} +func (userStoreUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return false, errors.New("store unavailable") +} +func (userStoreUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error { + return errors.New("store unavailable") +} +func 
(userStoreUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error { + return errors.New("store unavailable") +} +func (userStoreUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) { + return 0, errors.New("store unavailable") +} + +type userMemoryIdempotencyRepoStub struct { + mu sync.Mutex + nextID int64 + data map[string]*service.IdempotencyRecord +} + +func newUserMemoryIdempotencyRepoStub() *userMemoryIdempotencyRepoStub { + return &userMemoryIdempotencyRepoStub{ + nextID: 1, + data: make(map[string]*service.IdempotencyRecord), + } +} + +func (r *userMemoryIdempotencyRepoStub) key(scope, keyHash string) string { + return scope + "|" + keyHash +} + +func (r *userMemoryIdempotencyRepoStub) clone(in *service.IdempotencyRecord) *service.IdempotencyRecord { + if in == nil { + return nil + } + out := *in + if in.LockedUntil != nil { + v := *in.LockedUntil + out.LockedUntil = &v + } + if in.ResponseBody != nil { + v := *in.ResponseBody + out.ResponseBody = &v + } + if in.ResponseStatus != nil { + v := *in.ResponseStatus + out.ResponseStatus = &v + } + if in.ErrorReason != nil { + v := *in.ErrorReason + out.ErrorReason = &v + } + return &out +} + +func (r *userMemoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + k := r.key(record.Scope, record.IdempotencyKeyHash) + if _, ok := r.data[k]; ok { + return false, nil + } + cp := r.clone(record) + cp.ID = r.nextID + r.nextID++ + r.data[k] = cp + record.ID = cp.ID + return true, nil +} + +func (r *userMemoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) { + r.mu.Lock() + defer r.mu.Unlock() + return r.clone(r.data[r.key(scope, keyHash)]), nil +} + +func (r *userMemoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt 
time.Time) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + if rec.Status != fromStatus { + return false, nil + } + if rec.LockedUntil != nil && rec.LockedUntil.After(now) { + return false, nil + } + rec.Status = service.IdempotencyStatusProcessing + rec.LockedUntil = &newLockedUntil + rec.ExpiresAt = newExpiresAt + rec.ErrorReason = nil + return true, nil + } + return false, nil +} + +func (r *userMemoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + if rec.Status != service.IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint { + return false, nil + } + rec.LockedUntil = &newLockedUntil + rec.ExpiresAt = newExpiresAt + return true, nil + } + return false, nil +} + +func (r *userMemoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + rec.Status = service.IdempotencyStatusSucceeded + rec.LockedUntil = nil + rec.ExpiresAt = expiresAt + rec.ResponseStatus = &responseStatus + rec.ResponseBody = &responseBody + rec.ErrorReason = nil + return nil + } + return nil +} + +func (r *userMemoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + rec.Status = service.IdempotencyStatusFailedRetryable + rec.LockedUntil = &lockedUntil + rec.ExpiresAt = expiresAt + rec.ErrorReason = &errorReason + return nil + } + return nil +} + +func (r *userMemoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ 
int) (int64, error) { + return 0, nil +} + +func withUserSubject(userID int64) gin.HandlerFunc { + return func(c *gin.Context) { + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: userID}) + c.Next() + } +} + +func TestExecuteUserIdempotentJSONFallbackWithoutCoordinator(t *testing.T) { + gin.SetMode(gin.TestMode) + service.SetDefaultIdempotencyCoordinator(nil) + + var executed int + router := gin.New() + router.Use(withUserSubject(1)) + router.POST("/idempotent", func(c *gin.Context) { + executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) { + executed++ + return gin.H{"ok": true}, nil + }) + }) + + req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, 1, executed) +} + +func TestExecuteUserIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) { + gin.SetMode(gin.TestMode) + service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(userStoreUnavailableRepoStub{}, service.DefaultIdempotencyConfig())) + t.Cleanup(func() { + service.SetDefaultIdempotencyCoordinator(nil) + }) + + var executed int + router := gin.New() + router.Use(withUserSubject(2)) + router.POST("/idempotent", func(c *gin.Context) { + executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) { + executed++ + return gin.H{"ok": true}, nil + }) + }) + + req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", "k1") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusServiceUnavailable, rec.Code) + require.Equal(t, 0, executed) +} + +func 
TestExecuteUserIdempotentJSONConcurrentRetrySingleSideEffectAndReplay(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := newUserMemoryIdempotencyRepoStub() + cfg := service.DefaultIdempotencyConfig() + cfg.ProcessingTimeout = 2 * time.Second + service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(repo, cfg)) + t.Cleanup(func() { + service.SetDefaultIdempotencyCoordinator(nil) + }) + + var executed atomic.Int32 + router := gin.New() + router.Use(withUserSubject(3)) + router.POST("/idempotent", func(c *gin.Context) { + executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) { + executed.Add(1) + time.Sleep(80 * time.Millisecond) + return gin.H{"ok": true}, nil + }) + }) + + call := func() (int, http.Header) { + req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", "same-user-key") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + return rec.Code, rec.Header() + } + + var status1, status2 int + var wg sync.WaitGroup + wg.Add(2) + go func() { defer wg.Done(); status1, _ = call() }() + go func() { defer wg.Done(); status2, _ = call() }() + wg.Wait() + + require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status1) + require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status2) + require.Equal(t, int32(1), executed.Load()) + + status3, headers3 := call() + require.Equal(t, http.StatusOK, status3) + require.Equal(t, "true", headers3.Get("X-Idempotency-Replayed")) + require.Equal(t, int32(1), executed.Load()) +} diff --git a/backend/internal/handler/logging.go b/backend/internal/handler/logging.go new file mode 100644 index 00000000..2d5e6e22 --- /dev/null +++ b/backend/internal/handler/logging.go @@ -0,0 +1,19 @@ +package handler + +import ( + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/gin-gonic/gin" + 
"go.uber.org/zap" +) + +func requestLogger(c *gin.Context, component string, fields ...zap.Field) *zap.Logger { + base := logger.L() + if c != nil && c.Request != nil { + base = logger.FromContext(c.Request.Context()) + } + + if component != "" { + fields = append([]zap.Field{zap.String("component", component)}, fields...) + } + return base.With(fields...) +} diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index c08a8b0e..4bbd17ba 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -5,19 +5,23 @@ import ( "encoding/json" "errors" "fmt" - "io" - "log" "net/http" + "runtime/debug" + "strconv" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/config" + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" - "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" + coderws "github.com/coder/websocket" "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" ) // OpenAIGatewayHandler handles OpenAI API gateway requests @@ -25,6 +29,7 @@ type OpenAIGatewayHandler struct { gatewayService *service.OpenAIGatewayService billingCacheService *service.BillingCacheService apiKeyService *service.APIKeyService + usageRecordWorkerPool *service.UsageRecordWorkerPool errorPassthroughService *service.ErrorPassthroughService concurrencyHelper *ConcurrencyHelper maxAccountSwitches int @@ -36,6 +41,7 @@ func NewOpenAIGatewayHandler( concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService, apiKeyService *service.APIKeyService, + usageRecordWorkerPool *service.UsageRecordWorkerPool, errorPassthroughService *service.ErrorPassthroughService, cfg *config.Config, ) 
*OpenAIGatewayHandler { @@ -51,6 +57,7 @@ func NewOpenAIGatewayHandler( gatewayService: gatewayService, billingCacheService: billingCacheService, apiKeyService: apiKeyService, + usageRecordWorkerPool: usageRecordWorkerPool, errorPassthroughService: errorPassthroughService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), maxAccountSwitches: maxAccountSwitches, @@ -60,6 +67,13 @@ func NewOpenAIGatewayHandler( // Responses handles OpenAI Responses API endpoint // POST /openai/v1/responses func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { + // 局部兜底:确保该 handler 内部任何 panic 都不会击穿到进程级。 + streamStarted := false + defer h.recoverResponsesPanic(c, &streamStarted) + setOpenAIClientTransportHTTP(c) + + requestStart := time.Now() + // Get apiKey and user from context (set by ApiKeyAuth middleware) apiKey, ok := middleware2.GetAPIKeyFromContext(c) if !ok { @@ -72,9 +86,19 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") return } + reqLog := requestLogger( + c, + "handler.openai_gateway.responses", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + if !h.ensureResponsesDependencies(c, reqLog) { + return + } // Read request body - body, err := io.ReadAll(c.Request.Body) + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) if err != nil { if maxErr, ok := extractMaxBytesError(err); ok { h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) @@ -91,64 +115,51 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { setOpsRequestContext(c, "", false, body) - // Parse request body to map for potential modification - var reqBody map[string]any - if err := json.Unmarshal(body, &reqBody); err != nil { + // 校验请求体 JSON 合法性 + if !gjson.ValidBytes(body) { h.errorResponse(c, 
http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return } - // Extract model and stream - reqModel, _ := reqBody["model"].(string) - reqStream, _ := reqBody["stream"].(bool) - - // 验证 model 必填 - if reqModel == "" { + // 使用 gjson 只读提取字段做校验,避免完整 Unmarshal + modelResult := gjson.GetBytes(body, "model") + if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") return } + reqModel := modelResult.String() - userAgent := c.GetHeader("User-Agent") - if !openai.IsCodexCLIRequest(userAgent) { - existingInstructions, _ := reqBody["instructions"].(string) - if strings.TrimSpace(existingInstructions) == "" { - if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" { - reqBody["instructions"] = instructions - // Re-serialize body - body, err = json.Marshal(reqBody) - if err != nil { - h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request") - return - } - } + streamResult := gjson.GetBytes(body, "stream") + if streamResult.Exists() && streamResult.Type != gjson.True && streamResult.Type != gjson.False { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "invalid stream field type") + return + } + reqStream := streamResult.Bool() + reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) + previousResponseID := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String()) + if previousResponseID != "" { + previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID) + reqLog = reqLog.With( + zap.Bool("has_previous_response_id", true), + zap.String("previous_response_id_kind", previousResponseIDKind), + zap.Int("previous_response_id_len", len(previousResponseID)), + ) + if previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID { + 
reqLog.Warn("openai.request_validation_failed", + zap.String("reason", "previous_response_id_looks_like_message_id"), + ) + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id must be a response.id (resp_*), not a message id") + return } } setOpsRequestContext(c, reqModel, reqStream, body) // 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。 - // 要求 previous_response_id,或 input 内存在带 call_id 的 tool_call/function_call, - // 或带 id 且与 call_id 匹配的 item_reference。 - if service.HasFunctionCallOutput(reqBody) { - previousResponseID, _ := reqBody["previous_response_id"].(string) - if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) { - if service.HasFunctionCallOutputMissingCallID(reqBody) { - log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel) - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id") - return - } - callIDs := service.FunctionCallOutputCallIDs(reqBody) - if !service.HasItemReferenceForCallIDs(reqBody, callIDs) { - log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel) - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id") - return - } - } + if !h.validateFunctionCallOutputRequest(c, body, reqLog) { + return } - // Track if we've started streaming (for error handling) - streamStarted := false - // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 if h.errorPassthroughService != nil { service.BindErrorPassthroughService(c, h.errorPassthroughService) @@ -157,54 +168,28 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // Get subscription info (may be nil) subscription, _ := 
middleware2.GetSubscriptionFromContext(c) - // 0. Check if wait queue is full - maxWait := service.CalculateMaxWait(subject.Concurrency) - canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) - waitCounted := false - if err != nil { - log.Printf("Increment wait count failed: %v", err) - // On error, allow request to proceed - } else if !canWait { - h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") - return - } - if err == nil && canWait { - waitCounted = true - } - defer func() { - if waitCounted { - h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) - } - }() + service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds()) + routingStart := time.Now() - // 1. First acquire user concurrency slot - userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) - if err != nil { - log.Printf("User concurrency acquire failed: %v", err) - h.handleConcurrencyError(c, err, "user", streamStarted) + userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog) + if !acquired { return } - // User slot acquired: no longer waiting. - if waitCounted { - h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) - waitCounted = false - } // 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏 - userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) if userReleaseFunc != nil { defer userReleaseFunc() } // 2. 
Re-check billing eligibility after wait if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { - log.Printf("Billing eligibility check failed after wait: %v", err) + reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err)) status, code, message := billingErrorDetails(err) h.handleStreamingAwareError(c, status, code, message, streamStarted) return } // Generate session hash (header first; fallback to prompt_cache_key) - sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody) + sessionHash := h.gatewayService.GenerateSessionHash(c, body) maxAccountSwitches := h.maxAccountSwitches switchCount := 0 @@ -213,12 +198,23 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { for { // Select account supporting the requested model - log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel) - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) + reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( + c.Request.Context(), + apiKey.GroupID, + previousResponseID, + sessionHash, + reqModel, + failedAccountIDs, + service.OpenAIUpstreamTransportAny, + ) if err != nil { - log.Printf("[OpenAI Handler] SelectAccount failed: %v", err) + reqLog.Warn("openai.account_select_failed", + zap.Error(err), + zap.Int("excluded_account_count", len(failedAccountIDs)), + ) if len(failedAccountIDs) == 0 { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) return } if lastFailoverErr != nil { @@ -228,67 +224,53 @@ func (h 
*OpenAIGatewayHandler) Responses(c *gin.Context) { } return } - account := selection.Account - log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name) - setOpsSelectedAccount(c, account.ID) - - // 3. Acquire account concurrency slot - accountReleaseFunc := selection.ReleaseFunc - if !selection.Acquired { - if selection.WaitPlan == nil { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) - return - } - accountWaitCounted := false - canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) - if err != nil { - log.Printf("Increment account wait count failed: %v", err) - } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) - h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) - return - } - if err == nil && canWait { - accountWaitCounted = true - } - defer func() { - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - } - }() - - accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( - c, - account.ID, - selection.WaitPlan.MaxConcurrency, - selection.WaitPlan.Timeout, - reqStream, - &streamStarted, - ) - if err != nil { - log.Printf("Account concurrency acquire failed: %v", err) - h.handleConcurrencyError(c, err, "account", streamStarted) - return - } - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } - if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil { - log.Printf("Bind sticky session failed: %v", err) - } + if selection == nil || selection.Account == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available 
accounts", streamStarted) + return + } + if previousResponseID != "" && selection != nil && selection.Account != nil { + reqLog.Debug("openai.account_selected_with_previous_response_id", zap.Int64("account_id", selection.Account.ID)) + } + reqLog.Debug("openai.account_schedule_decision", + zap.String("layer", scheduleDecision.Layer), + zap.Bool("sticky_previous_hit", scheduleDecision.StickyPreviousHit), + zap.Bool("sticky_session_hit", scheduleDecision.StickySessionHit), + zap.Int("candidate_count", scheduleDecision.CandidateCount), + zap.Int("top_k", scheduleDecision.TopK), + zap.Int64("latency_ms", scheduleDecision.LatencyMs), + zap.Float64("load_skew", scheduleDecision.LoadSkew), + ) + account := selection.Account + reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name)) + setOpsSelectedAccount(c, account.ID, account.Platform) + + accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog) + if !acquired { + return } - // 账号槽位/等待计数需要在超时或断开时安全回收 - accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) // Forward request + service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) + forwardStart := time.Now() result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) + forwardDurationMs := time.Since(forwardStart).Milliseconds() if accountReleaseFunc != nil { accountReleaseFunc() } + upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey) + responseLatencyMs := forwardDurationMs + if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs { + responseLatencyMs = forwardDurationMs - upstreamLatencyMs + } + service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs) + if err == nil && result != nil && result.FirstTokenMs != nil { + service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, 
int64(*result.FirstTokenMs)) + } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + h.gatewayService.RecordOpenAIAccountSwitch() failedAccountIDs[account.ID] = struct{}{} lastFailoverErr = failoverErr if switchCount >= maxAccountSwitches { @@ -296,39 +278,631 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { return } switchCount++ - log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) + reqLog.Warn("openai.upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) continue } - // Error response already handled in Forward, just log - log.Printf("Account %d: Forward request failed: %v", account.ID, err) + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + fields := []zap.Field{ + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + } + if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) { + reqLog.Warn("openai.forward_failed", fields...) + return + } + reqLog.Error("openai.forward_failed", fields...) 
return } + if result != nil { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) + } else { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil) + } // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) - // Async record usage - go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua, ip string) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 + h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ Result: result, APIKey: apiKey, User: apiKey.User, - Account: usedAccount, + Account: account, Subscription: subscription, - UserAgent: ua, - IPAddress: ip, + UserAgent: userAgent, + IPAddress: clientIP, APIKeyService: h.apiKeyService, }); err != nil { - log.Printf("Record usage failed: %v", err) + logger.L().With( + zap.String("component", "handler.openai_gateway.responses"), + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.String("model", reqModel), + zap.Int64("account_id", account.ID), + ).Error("openai.record_usage_failed", zap.Error(err)) } - }(result, account, userAgent, clientIP) + }) + reqLog.Debug("openai.request_completed", + zap.Int64("account_id", account.ID), + zap.Int("switch_count", switchCount), + ) return } } +func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool { + if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() { + return true + } + + var reqBody map[string]any + if err := json.Unmarshal(body, &reqBody); err != nil { + // 保持原有容错语义:解析失败时跳过预校验,沿用后续上游校验结果。 + return true + } + + c.Set(service.OpenAIParsedRequestBodyKey, reqBody) + validation := 
service.ValidateFunctionCallOutputContext(reqBody) + if !validation.HasFunctionCallOutput { + return true + } + + previousResponseID, _ := reqBody["previous_response_id"].(string) + if strings.TrimSpace(previousResponseID) != "" || validation.HasToolCallContext { + return true + } + + if validation.HasFunctionCallOutputMissingCallID { + reqLog.Warn("openai.request_validation_failed", + zap.String("reason", "function_call_output_missing_call_id"), + ) + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id") + return false + } + if validation.HasItemReferenceForAllCallIDs { + return true + } + + reqLog.Warn("openai.request_validation_failed", + zap.String("reason", "function_call_output_missing_item_reference"), + ) + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id") + return false +} + +func (h *OpenAIGatewayHandler) acquireResponsesUserSlot( + c *gin.Context, + userID int64, + userConcurrency int, + reqStream bool, + streamStarted *bool, + reqLog *zap.Logger, +) (func(), bool) { + ctx := c.Request.Context() + userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, userID, userConcurrency) + if err != nil { + reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err)) + h.handleConcurrencyError(c, err, "user", *streamStarted) + return nil, false + } + if userAcquired { + return wrapReleaseOnDone(ctx, userReleaseFunc), true + } + + maxWait := service.CalculateMaxWait(userConcurrency) + canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(ctx, userID, maxWait) + if waitErr != nil { + reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr)) + // 
按现有降级语义:等待计数异常时放行后续抢槽流程 + } else if !canWait { + reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait)) + h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") + return nil, false + } + + waitCounted := waitErr == nil && canWait + defer func() { + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(ctx, userID) + } + }() + + userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, userID, userConcurrency, reqStream, streamStarted) + if err != nil { + reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err)) + h.handleConcurrencyError(c, err, "user", *streamStarted) + return nil, false + } + + // 槽位获取成功后,立刻退出等待计数。 + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(ctx, userID) + waitCounted = false + } + return wrapReleaseOnDone(ctx, userReleaseFunc), true +} + +func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot( + c *gin.Context, + groupID *int64, + sessionHash string, + selection *service.AccountSelectionResult, + reqStream bool, + streamStarted *bool, + reqLog *zap.Logger, +) (func(), bool) { + if selection == nil || selection.Account == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted) + return nil, false + } + + ctx := c.Request.Context() + account := selection.Account + if selection.Acquired { + return wrapReleaseOnDone(ctx, selection.ReleaseFunc), true + } + if selection.WaitPlan == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted) + return nil, false + } + + fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot( + ctx, + account.ID, + selection.WaitPlan.MaxConcurrency, + ) + if err != nil { + reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + h.handleConcurrencyError(c, err, "account", 
*streamStarted) + return nil, false + } + if fastAcquired { + if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil { + reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } + return wrapReleaseOnDone(ctx, fastReleaseFunc), true + } + + canWait, waitErr := h.concurrencyHelper.IncrementAccountWaitCount(ctx, account.ID, selection.WaitPlan.MaxWaiting) + if waitErr != nil { + reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(waitErr)) + } else if !canWait { + reqLog.Info("openai.account_wait_queue_full", + zap.Int64("account_id", account.ID), + zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), + ) + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", *streamStarted) + return nil, false + } + + accountWaitCounted := waitErr == nil && canWait + releaseWait := func() { + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(ctx, account.ID) + accountWaitCounted = false + } + } + defer releaseWait() + + accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + reqStream, + streamStarted, + ) + if err != nil { + reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + h.handleConcurrencyError(c, err, "account", *streamStarted) + return nil, false + } + + // Slot acquired: no longer waiting in queue. 
+ releaseWait() + if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil { + reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } + return wrapReleaseOnDone(ctx, accountReleaseFunc), true +} + +// ResponsesWebSocket handles OpenAI Responses API WebSocket ingress endpoint +// GET /openai/v1/responses (Upgrade: websocket) +func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { + if !isOpenAIWSUpgradeRequest(c.Request) { + h.errorResponse(c, http.StatusUpgradeRequired, "invalid_request_error", "WebSocket upgrade required (Upgrade: websocket)") + return + } + setOpenAIClientTransportWS(c) + + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + + reqLog := requestLogger( + c, + "handler.openai_gateway.responses_ws", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.Bool("openai_ws_mode", true), + ) + if !h.ensureResponsesDependencies(c, reqLog) { + return + } + reqLog.Info("openai.websocket_ingress_started") + clientIP := ip.GetClientIP(c) + userAgent := strings.TrimSpace(c.GetHeader("User-Agent")) + + wsConn, err := coderws.Accept(c.Writer, c.Request, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + reqLog.Warn("openai.websocket_accept_failed", + zap.Error(err), + zap.String("client_ip", clientIP), + zap.String("request_user_agent", userAgent), + zap.String("upgrade_header", strings.TrimSpace(c.GetHeader("Upgrade"))), + zap.String("connection_header", strings.TrimSpace(c.GetHeader("Connection"))), + zap.String("sec_websocket_version", 
strings.TrimSpace(c.GetHeader("Sec-WebSocket-Version"))), + zap.Bool("has_sec_websocket_key", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Key")) != ""), + ) + return + } + defer func() { + _ = wsConn.CloseNow() + }() + wsConn.SetReadLimit(16 * 1024 * 1024) + + ctx := c.Request.Context() + readCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + msgType, firstMessage, err := wsConn.Read(readCtx) + cancel() + if err != nil { + closeStatus, closeReason := summarizeWSCloseErrorForLog(err) + reqLog.Warn("openai.websocket_read_first_message_failed", + zap.Error(err), + zap.String("client_ip", clientIP), + zap.String("close_status", closeStatus), + zap.String("close_reason", closeReason), + zap.Duration("read_timeout", 30*time.Second), + ) + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "missing first response.create message") + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "unsupported websocket message type") + return + } + if !gjson.ValidBytes(firstMessage) { + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "invalid JSON payload") + return + } + + reqModel := strings.TrimSpace(gjson.GetBytes(firstMessage, "model").String()) + if reqModel == "" { + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model is required in first response.create payload") + return + } + previousResponseID := strings.TrimSpace(gjson.GetBytes(firstMessage, "previous_response_id").String()) + previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID) + if previousResponseID != "" && previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID { + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "previous_response_id must be a response.id (resp_*), not a message id") + return + } + reqLog = reqLog.With( + zap.Bool("ws_ingress", true), + zap.String("model", reqModel), + zap.Bool("has_previous_response_id", 
previousResponseID != ""), + zap.String("previous_response_id_kind", previousResponseIDKind), + ) + setOpsRequestContext(c, reqModel, true, firstMessage) + + var currentUserRelease func() + var currentAccountRelease func() + releaseTurnSlots := func() { + if currentAccountRelease != nil { + currentAccountRelease() + currentAccountRelease = nil + } + if currentUserRelease != nil { + currentUserRelease() + currentUserRelease = nil + } + } + // 必须尽早注册,确保任何 early return 都能释放已获取的并发槽位。 + defer releaseTurnSlots() + + userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency) + if err != nil { + reqLog.Warn("openai.websocket_user_slot_acquire_failed", zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire user concurrency slot") + return + } + if !userAcquired { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "too many concurrent requests, please retry later") + return + } + currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc) + + subscription, _ := middleware2.GetSubscriptionFromContext(c) + if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("openai.websocket_billing_eligibility_check_failed", zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "billing check failed") + return + } + + sessionHash := h.gatewayService.GenerateSessionHashWithFallback( + c, + firstMessage, + openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID), + ) + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( + ctx, + apiKey.GroupID, + previousResponseID, + sessionHash, + reqModel, + nil, + service.OpenAIUpstreamTransportResponsesWebsocketV2, + ) + if err != nil { + reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account") + 
return + } + if selection == nil || selection.Account == nil { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account") + return + } + + account := selection.Account + accountMaxConcurrency := account.Concurrency + if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 { + accountMaxConcurrency = selection.WaitPlan.MaxConcurrency + } + accountReleaseFunc := selection.ReleaseFunc + if !selection.Acquired { + if selection.WaitPlan == nil { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later") + return + } + fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot( + ctx, + account.ID, + selection.WaitPlan.MaxConcurrency, + ) + if err != nil { + reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot") + return + } + if !fastAcquired { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later") + return + } + accountReleaseFunc = fastReleaseFunc + } + currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc) + if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil { + reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } + + token, _, err := h.gatewayService.GetAccessToken(ctx, account) + if err != nil { + reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token") + return + } + + reqLog.Debug("openai.websocket_account_selected", + zap.Int64("account_id", account.ID), + zap.String("account_name", account.Name), + zap.String("schedule_layer", scheduleDecision.Layer), + zap.Int("candidate_count", 
scheduleDecision.CandidateCount), + ) + + hooks := &service.OpenAIWSIngressHooks{ + BeforeTurn: func(turn int) error { + if turn == 1 { + return nil + } + // 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。 + releaseTurnSlots() + // 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。 + userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency) + if err != nil { + return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err) + } + if !userAcquired { + return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil) + } + accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency) + if err != nil { + if userReleaseFunc != nil { + userReleaseFunc() + } + return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err) + } + if !accountAcquired { + if userReleaseFunc != nil { + userReleaseFunc() + } + return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil) + } + currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc) + currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc) + return nil + }, + AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) { + releaseTurnSlots() + if turnErr != nil || result == nil { + return + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) + h.submitUsageRecordTask(func(taskCtx context.Context) { + if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + APIKeyService: h.apiKeyService, + }); err != nil { + reqLog.Error("openai.websocket_record_usage_failed", + 
zap.Int64("account_id", account.ID), + zap.String("request_id", result.RequestID), + zap.Error(err), + ) + } + }) + }, + } + + if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + closeStatus, closeReason := summarizeWSCloseErrorForLog(err) + reqLog.Warn("openai.websocket_proxy_failed", + zap.Int64("account_id", account.ID), + zap.Error(err), + zap.String("close_status", closeStatus), + zap.String("close_reason", closeReason), + ) + var closeErr *service.OpenAIWSClientCloseError + if errors.As(err, &closeErr) { + closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason()) + return + } + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed") + return + } + reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID)) +} + +func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) { + recovered := recover() + if recovered == nil { + return + } + + started := false + if streamStarted != nil { + started = *streamStarted + } + wroteFallback := h.ensureForwardErrorResponse(c, started) + requestLogger(c, "handler.openai_gateway.responses").Error( + "openai.responses_panic_recovered", + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Any("panic", recovered), + zap.ByteString("stack", debug.Stack()), + ) +} + +func (h *OpenAIGatewayHandler) ensureResponsesDependencies(c *gin.Context, reqLog *zap.Logger) bool { + missing := h.missingResponsesDependencies() + if len(missing) == 0 { + return true + } + + if reqLog == nil { + reqLog = requestLogger(c, "handler.openai_gateway.responses") + } + reqLog.Error("openai.handler_dependencies_missing", zap.Strings("missing_dependencies", missing)) + + if c != nil && c.Writer != nil && !c.Writer.Written() { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": 
gin.H{ + "type": "api_error", + "message": "Service temporarily unavailable", + }, + }) + } + return false +} + +func (h *OpenAIGatewayHandler) missingResponsesDependencies() []string { + missing := make([]string, 0, 5) + if h == nil { + return append(missing, "handler") + } + if h.gatewayService == nil { + missing = append(missing, "gatewayService") + } + if h.billingCacheService == nil { + missing = append(missing, "billingCacheService") + } + if h.apiKeyService == nil { + missing = append(missing, "apiKeyService") + } + if h.concurrencyHelper == nil || h.concurrencyHelper.concurrencyService == nil { + missing = append(missing, "concurrencyHelper") + } + return missing +} + +func getContextInt64(c *gin.Context, key string) (int64, bool) { + if c == nil || key == "" { + return 0, false + } + v, ok := c.Get(key) + if !ok { + return 0, false + } + switch t := v.(type) { + case int64: + return t, true + case int: + return int64(t), true + case int32: + return int64(t), true + case float64: + return int64(t), true + default: + return 0, false + } +} + +func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { + if task == nil { + return + } + if h.usageRecordWorkerPool != nil { + h.usageRecordWorkerPool.Submit(task) + return + } + // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。 + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + defer func() { + if recovered := recover(); recovered != nil { + logger.L().With( + zap.String("component", "handler.openai_gateway.responses"), + zap.Any("panic", recovered), + ).Error("openai.usage_record_task_panic_recovered") + } + }() + task(ctx) +} + // handleConcurrencyError handles concurrency-related errors with proper 429 response func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", @@ -397,8 +971,8 @@ func (h 
*OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status // Stream already started, send error as SSE event then close flusher, ok := c.Writer.(http.Flusher) if ok { - // Send error event in OpenAI SSE format - errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message) + // SSE 错误事件固定 schema,使用 Quote 直拼可避免额外 Marshal 分配。 + errorEvent := "event: error\ndata: " + `{"error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n" if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil { _ = c.Error(err) } @@ -411,6 +985,25 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status h.errorResponse(c, status, errType, message) } +// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。 +func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool { + if c == nil || c.Writer == nil || c.Writer.Written() { + return false + } + h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted) + return true +} + +func shouldLogOpenAIForwardFailureAsWarn(c *gin.Context, wroteFallback bool) bool { + if wroteFallback { + return false + } + if c == nil || c.Writer == nil { + return false + } + return c.Writer.Written() +} + // errorResponse returns OpenAI API format error response func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) { c.JSON(status, gin.H{ @@ -420,3 +1013,61 @@ func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType }, }) } + +func setOpenAIClientTransportHTTP(c *gin.Context) { + service.SetOpenAIClientTransport(c, service.OpenAIClientTransportHTTP) +} + +func setOpenAIClientTransportWS(c *gin.Context) { + service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS) +} + +func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string { 
+ gid := int64(0) + if groupID != nil { + gid = *groupID + } + return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKeyID) +} + +func isOpenAIWSUpgradeRequest(r *http.Request) bool { + if r == nil { + return false + } + if !strings.EqualFold(strings.TrimSpace(r.Header.Get("Upgrade")), "websocket") { + return false + } + return strings.Contains(strings.ToLower(strings.TrimSpace(r.Header.Get("Connection"))), "upgrade") +} + +func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason string) { + if conn == nil { + return + } + reason = strings.TrimSpace(reason) + if len(reason) > 120 { + reason = reason[:120] + } + _ = conn.Close(status, reason) + _ = conn.CloseNow() +} + +func summarizeWSCloseErrorForLog(err error) (string, string) { + if err == nil { + return "-", "-" + } + statusCode := coderws.CloseStatus(err) + if statusCode == -1 { + return "-", "-" + } + closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String()) + closeReason := "-" + var closeErr coderws.CloseError + if errors.As(err, &closeErr) { + reason := strings.TrimSpace(closeErr.Reason) + if reason != "" { + closeReason = reason + } + } + return closeStatus, closeReason +} diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go new file mode 100644 index 00000000..a26b3a0c --- /dev/null +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -0,0 +1,677 @@ +package handler + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +func 
TestOpenAIHandleStreamingAwareError_JSONEscaping(t *testing.T) { + tests := []struct { + name string + errType string + message string + }{ + { + name: "包含双引号的消息", + errType: "server_error", + message: `upstream returned "invalid" response`, + }, + { + name: "包含反斜杠的消息", + errType: "server_error", + message: `path C:\Users\test\file.txt not found`, + }, + { + name: "包含双引号和反斜杠的消息", + errType: "upstream_error", + message: `error parsing "key\value": unexpected token`, + }, + { + name: "包含换行符的消息", + errType: "server_error", + message: "line1\nline2\ttab", + }, + { + name: "普通消息", + errType: "upstream_error", + message: "Upstream service temporarily unavailable", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + h := &OpenAIGatewayHandler{} + h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true) + + body := w.Body.String() + + // 验证 SSE 格式:event: error\ndata: {JSON}\n\n + assert.True(t, strings.HasPrefix(body, "event: error\n"), "应以 'event: error\\n' 开头") + assert.True(t, strings.HasSuffix(body, "\n\n"), "应以 '\\n\\n' 结尾") + + // 提取 data 部分 + lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n") + require.Len(t, lines, 2, "应有 event 行和 data 行") + dataLine := lines[1] + require.True(t, strings.HasPrefix(dataLine, "data: "), "第二行应以 'data: ' 开头") + jsonStr := strings.TrimPrefix(dataLine, "data: ") + + // 验证 JSON 合法性 + var parsed map[string]any + err := json.Unmarshal([]byte(jsonStr), &parsed) + require.NoError(t, err, "JSON 应能被成功解析,原始 JSON: %s", jsonStr) + + // 验证结构 + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok, "应包含 error 对象") + assert.Equal(t, tt.errType, errorObj["type"]) + assert.Equal(t, tt.message, errorObj["message"]) + }) + } +} + +func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) { + 
gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + h := &OpenAIGatewayHandler{} + h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "test error", false) + + // 非流式应返回 JSON 响应 + assert.Equal(t, http.StatusBadGateway, w.Code) + + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errorObj["type"]) + assert.Equal(t, "test error", errorObj["message"]) +} + +func TestReadRequestBodyWithPrealloc(t *testing.T) { + payload := `{"model":"gpt-5","input":"hello"}` + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(payload)) + req.ContentLength = int64(len(payload)) + + body, err := pkghttputil.ReadRequestBodyWithPrealloc(req) + require.NoError(t, err) + require.Equal(t, payload, string(body)) +} + +func TestReadRequestBodyWithPrealloc_MaxBytesError(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(strings.Repeat("x", 8))) + req.Body = http.MaxBytesReader(rec, req.Body, 4) + + _, err := pkghttputil.ReadRequestBodyWithPrealloc(req) + require.Error(t, err) + var maxErr *http.MaxBytesError + require.ErrorAs(t, err, &maxErr) +} + +func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + h := &OpenAIGatewayHandler{} + wrote := h.ensureForwardErrorResponse(c, false) + + require.True(t, wrote) + require.Equal(t, http.StatusBadGateway, w.Code) + + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, 
ok) + assert.Equal(t, "upstream_error", errorObj["type"]) + assert.Equal(t, "Upstream request failed", errorObj["message"]) +} + +func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.String(http.StatusTeapot, "already written") + + h := &OpenAIGatewayHandler{} + wrote := h.ensureForwardErrorResponse(c, false) + + require.False(t, wrote) + require.Equal(t, http.StatusTeapot, w.Code) + assert.Equal(t, "already written", w.Body.String()) +} + +func TestShouldLogOpenAIForwardFailureAsWarn(t *testing.T) { + gin.SetMode(gin.TestMode) + + t.Run("fallback_written_should_not_downgrade", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + require.False(t, shouldLogOpenAIForwardFailureAsWarn(c, true)) + }) + + t.Run("context_nil_should_not_downgrade", func(t *testing.T) { + require.False(t, shouldLogOpenAIForwardFailureAsWarn(nil, false)) + }) + + t.Run("response_not_written_should_not_downgrade", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + require.False(t, shouldLogOpenAIForwardFailureAsWarn(c, false)) + }) + + t.Run("response_already_written_should_downgrade", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.String(http.StatusForbidden, "already written") + require.True(t, shouldLogOpenAIForwardFailureAsWarn(c, false)) + }) +} + +func TestOpenAIRecoverResponsesPanic_WritesFallbackResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + + h := 
&OpenAIGatewayHandler{} + streamStarted := false + require.NotPanics(t, func() { + func() { + defer h.recoverResponsesPanic(c, &streamStarted) + panic("test panic") + }() + }) + + require.Equal(t, http.StatusBadGateway, w.Code) + + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errorObj["type"]) + assert.Equal(t, "Upstream request failed", errorObj["message"]) +} + +func TestOpenAIRecoverResponsesPanic_NoPanicNoWrite(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + + h := &OpenAIGatewayHandler{} + streamStarted := false + require.NotPanics(t, func() { + func() { + defer h.recoverResponsesPanic(c, &streamStarted) + }() + }) + + require.False(t, c.Writer.Written()) + assert.Equal(t, "", w.Body.String()) +} + +func TestOpenAIRecoverResponsesPanic_DoesNotOverrideWrittenResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + c.String(http.StatusTeapot, "already written") + + h := &OpenAIGatewayHandler{} + streamStarted := false + require.NotPanics(t, func() { + func() { + defer h.recoverResponsesPanic(c, &streamStarted) + panic("test panic") + }() + }) + + require.Equal(t, http.StatusTeapot, w.Code) + assert.Equal(t, "already written", w.Body.String()) +} + +func TestOpenAIMissingResponsesDependencies(t *testing.T) { + t.Run("nil_handler", func(t *testing.T) { + var h *OpenAIGatewayHandler + require.Equal(t, []string{"handler"}, h.missingResponsesDependencies()) + }) + + t.Run("all_dependencies_missing", func(t *testing.T) { + h := &OpenAIGatewayHandler{} + require.Equal(t, + []string{"gatewayService", "billingCacheService", 
"apiKeyService", "concurrencyHelper"}, + h.missingResponsesDependencies(), + ) + }) + + t.Run("all_dependencies_present", func(t *testing.T) { + h := &OpenAIGatewayHandler{ + gatewayService: &service.OpenAIGatewayService{}, + billingCacheService: &service.BillingCacheService{}, + apiKeyService: &service.APIKeyService{}, + concurrencyHelper: &ConcurrencyHelper{ + concurrencyService: &service.ConcurrencyService{}, + }, + } + require.Empty(t, h.missingResponsesDependencies()) + }) +} + +func TestOpenAIEnsureResponsesDependencies(t *testing.T) { + t.Run("missing_dependencies_returns_503", func(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + + h := &OpenAIGatewayHandler{} + ok := h.ensureResponsesDependencies(c, nil) + + require.False(t, ok) + require.Equal(t, http.StatusServiceUnavailable, w.Code) + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + errorObj, exists := parsed["error"].(map[string]any) + require.True(t, exists) + assert.Equal(t, "api_error", errorObj["type"]) + assert.Equal(t, "Service temporarily unavailable", errorObj["message"]) + }) + + t.Run("already_written_response_not_overridden", func(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + c.String(http.StatusTeapot, "already written") + + h := &OpenAIGatewayHandler{} + ok := h.ensureResponsesDependencies(c, nil) + + require.False(t, ok) + require.Equal(t, http.StatusTeapot, w.Code) + assert.Equal(t, "already written", w.Body.String()) + }) + + t.Run("dependencies_ready_returns_true_and_no_write", func(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + + 
h := &OpenAIGatewayHandler{ + gatewayService: &service.OpenAIGatewayService{}, + billingCacheService: &service.BillingCacheService{}, + apiKeyService: &service.APIKeyService{}, + concurrencyHelper: &ConcurrencyHelper{ + concurrencyService: &service.ConcurrencyService{}, + }, + } + ok := h.ensureResponsesDependencies(c, nil) + + require.True(t, ok) + require.False(t, c.Writer.Written()) + assert.Equal(t, "", w.Body.String()) + }) +} + +func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(`{"model":"gpt-5","stream":false}`)) + c.Request.Header.Set("Content-Type", "application/json") + + groupID := int64(2) + c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{ + ID: 10, + GroupID: &groupID, + }) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ + UserID: 1, + Concurrency: 1, + }) + + // 故意使用未初始化依赖,验证快速失败而不是崩溃。 + h := &OpenAIGatewayHandler{} + require.NotPanics(t, func() { + h.Responses(c) + }) + + require.Equal(t, http.StatusServiceUnavailable, w.Code) + + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "api_error", errorObj["type"]) + assert.Equal(t, "Service temporarily unavailable", errorObj["message"]) +} + +func TestOpenAIResponses_SetsClientTransportHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(`{"model":"gpt-5"}`)) + c.Request.Header.Set("Content-Type", "application/json") + + h := &OpenAIGatewayHandler{} + h.Responses(c) + + require.Equal(t, http.StatusUnauthorized, w.Code) + require.Equal(t, 
service.OpenAIClientTransportHTTP, service.GetOpenAIClientTransport(c)) +} + +func TestOpenAIResponses_RejectsMessageIDAsPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader( + `{"model":"gpt-5.1","stream":false,"previous_response_id":"msg_123456","input":[{"type":"input_text","text":"hello"}]}`, + )) + c.Request.Header.Set("Content-Type", "application/json") + + groupID := int64(2) + c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{ + ID: 101, + GroupID: &groupID, + User: &service.User{ID: 1}, + }) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ + UserID: 1, + Concurrency: 1, + }) + + h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil) + h.Responses(c) + + require.Equal(t, http.StatusBadRequest, w.Code) + require.Contains(t, w.Body.String(), "previous_response_id must be a response.id") +} + +func TestOpenAIResponsesWebSocket_SetsClientTransportWSWhenUpgradeValid(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/openai/v1/responses", nil) + c.Request.Header.Set("Upgrade", "websocket") + c.Request.Header.Set("Connection", "Upgrade") + + h := &OpenAIGatewayHandler{} + h.ResponsesWebSocket(c) + + require.Equal(t, http.StatusUnauthorized, w.Code) + require.Equal(t, service.OpenAIClientTransportWS, service.GetOpenAIClientTransport(c)) +} + +func TestOpenAIResponsesWebSocket_InvalidUpgradeDoesNotSetTransport(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/openai/v1/responses", nil) + + h := &OpenAIGatewayHandler{} + h.ResponsesWebSocket(c) + + require.Equal(t, http.StatusUpgradeRequired, w.Code) + require.Equal(t, 
service.OpenAIClientTransportUnknown, service.GetOpenAIClientTransport(c)) +} + +func TestOpenAIResponsesWebSocket_RejectsMessageIDAsPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil) + wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1}) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte( + `{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_abc123"}`, + )) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, _, err = clientConn.Read(readCtx) + cancelRead() + require.Error(t, err) + var closeErr coderws.CloseError + require.ErrorAs(t, err, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code) + require.Contains(t, strings.ToLower(closeErr.Reason), "previous_response_id") +} + +func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailure(t *testing.T) { + gin.SetMode(gin.TestMode) + + cache := &concurrencyCacheMock{ + acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + return false, errors.New("user slot unavailable") + }, + } + h := newOpenAIHandlerForPreviousResponseIDValidation(t, cache) + wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1}) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err 
:= coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte( + `{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_123"}`, + )) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, _, err = clientConn.Read(readCtx) + cancelRead() + require.Error(t, err) + var closeErr coderws.CloseError + require.ErrorAs(t, err, &closeErr) + require.Equal(t, coderws.StatusInternalError, closeErr.Code) + require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot") +} + +func TestSetOpenAIClientTransportHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + setOpenAIClientTransportHTTP(c) + require.Equal(t, service.OpenAIClientTransportHTTP, service.GetOpenAIClientTransport(c)) +} + +func TestSetOpenAIClientTransportWS(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + setOpenAIClientTransportWS(c) + require.Equal(t, service.OpenAIClientTransportWS, service.GetOpenAIClientTransport(c)) +} + +// TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性 +func TestOpenAIHandler_GjsonExtraction(t *testing.T) { + tests := []struct { + name string + body string + wantModel string + wantStream bool + }{ + {"正常提取", `{"model":"gpt-4","stream":true,"input":"hello"}`, "gpt-4", true}, + {"stream false", `{"model":"gpt-4","stream":false}`, "gpt-4", false}, + {"无 stream 字段", `{"model":"gpt-4"}`, "gpt-4", false}, + {"model 缺失", `{"stream":true}`, "", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + 
body := []byte(tt.body) + modelResult := gjson.GetBytes(body, "model") + model := "" + if modelResult.Type == gjson.String { + model = modelResult.String() + } + stream := gjson.GetBytes(body, "stream").Bool() + require.Equal(t, tt.wantModel, model) + require.Equal(t, tt.wantStream, stream) + }) + } +} + +// TestOpenAIHandler_GjsonValidation 验证修复后的 JSON 合法性和类型校验 +func TestOpenAIHandler_GjsonValidation(t *testing.T) { + // 非法 JSON 被 gjson.ValidBytes 拦截 + require.False(t, gjson.ValidBytes([]byte(`{invalid json`))) + + // model 为数字 → 类型不是 gjson.String,应被拒绝 + body := []byte(`{"model":123}`) + modelResult := gjson.GetBytes(body, "model") + require.True(t, modelResult.Exists()) + require.NotEqual(t, gjson.String, modelResult.Type) + + // model 为 null → 类型不是 gjson.String,应被拒绝 + body2 := []byte(`{"model":null}`) + modelResult2 := gjson.GetBytes(body2, "model") + require.True(t, modelResult2.Exists()) + require.NotEqual(t, gjson.String, modelResult2.Type) + + // stream 为 string → 类型既不是 True 也不是 False,应被拒绝 + body3 := []byte(`{"model":"gpt-4","stream":"true"}`) + streamResult := gjson.GetBytes(body3, "stream") + require.True(t, streamResult.Exists()) + require.NotEqual(t, gjson.True, streamResult.Type) + require.NotEqual(t, gjson.False, streamResult.Type) + + // stream 为 int → 同上 + body4 := []byte(`{"model":"gpt-4","stream":1}`) + streamResult2 := gjson.GetBytes(body4, "stream") + require.True(t, streamResult2.Exists()) + require.NotEqual(t, gjson.True, streamResult2.Type) + require.NotEqual(t, gjson.False, streamResult2.Type) +} + +// TestOpenAIHandler_InstructionsInjection 验证 instructions 的 gjson/sjson 注入逻辑 +func TestOpenAIHandler_InstructionsInjection(t *testing.T) { + // 测试 1:无 instructions → 注入 + body := []byte(`{"model":"gpt-4"}`) + existing := gjson.GetBytes(body, "instructions").String() + require.Empty(t, existing) + newBody, err := sjson.SetBytes(body, "instructions", "test instruction") + require.NoError(t, err) + require.Equal(t, "test instruction", 
gjson.GetBytes(newBody, "instructions").String()) + + // 测试 2:已有 instructions → 不覆盖 + body2 := []byte(`{"model":"gpt-4","instructions":"existing"}`) + existing2 := gjson.GetBytes(body2, "instructions").String() + require.Equal(t, "existing", existing2) + + // 测试 3:空白 instructions → 注入 + body3 := []byte(`{"model":"gpt-4","instructions":" "}`) + existing3 := strings.TrimSpace(gjson.GetBytes(body3, "instructions").String()) + require.Empty(t, existing3) + + // 测试 4:sjson.SetBytes 返回错误时不应 panic + // 正常 JSON 不会产生 sjson 错误,验证返回值被正确处理 + validBody := []byte(`{"model":"gpt-4"}`) + result, setErr := sjson.SetBytes(validBody, "instructions", "hello") + require.NoError(t, setErr) + require.True(t, gjson.ValidBytes(result)) +} + +func newOpenAIHandlerForPreviousResponseIDValidation(t *testing.T, cache *concurrencyCacheMock) *OpenAIGatewayHandler { + t.Helper() + if cache == nil { + cache = &concurrencyCacheMock{ + acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil + }, + acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil + }, + } + } + return &OpenAIGatewayHandler{ + gatewayService: &service.OpenAIGatewayService{}, + billingCacheService: &service.BillingCacheService{}, + apiKeyService: &service.APIKeyService{}, + concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second), + } +} + +func newOpenAIWSHandlerTestServer(t *testing.T, h *OpenAIGatewayHandler, subject middleware.AuthSubject) *httptest.Server { + t.Helper() + groupID := int64(2) + apiKey := &service.APIKey{ + ID: 101, + GroupID: &groupID, + User: &service.User{ID: subject.UserID}, + } + router := gin.New() + router.Use(func(c *gin.Context) { + c.Set(string(middleware.ContextKeyAPIKey), apiKey) + c.Set(string(middleware.ContextKeyUser), subject) + c.Next() + }) + router.GET("/openai/v1/responses", 
h.ResponsesWebSocket) + return httptest.NewServer(router) +} diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go index cb62ceae..2f53d655 100644 --- a/backend/internal/handler/ops_error_logger.go +++ b/backend/internal/handler/ops_error_logger.go @@ -41,9 +41,8 @@ const ( ) type opsErrorLogJob struct { - ops *service.OpsService - entry *service.OpsInsertErrorLogInput - requestBody []byte + ops *service.OpsService + entry *service.OpsInsertErrorLogInput } var ( @@ -58,6 +57,7 @@ var ( opsErrorLogEnqueued atomic.Int64 opsErrorLogDropped atomic.Int64 opsErrorLogProcessed atomic.Int64 + opsErrorLogSanitized atomic.Int64 opsErrorLogLastDropLogAt atomic.Int64 @@ -94,7 +94,7 @@ func startOpsErrorLogWorkers() { } }() ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout) - _ = job.ops.RecordError(ctx, job.entry, job.requestBody) + _ = job.ops.RecordError(ctx, job.entry, nil) cancel() opsErrorLogProcessed.Add(1) }() @@ -103,7 +103,7 @@ func startOpsErrorLogWorkers() { } } -func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput, requestBody []byte) { +func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) { if ops == nil || entry == nil { return } @@ -129,7 +129,7 @@ func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLo } select { - case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry, requestBody: requestBody}: + case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry}: opsErrorLogQueueLen.Add(1) opsErrorLogEnqueued.Add(1) default: @@ -205,6 +205,10 @@ func OpsErrorLogProcessedTotal() int64 { return opsErrorLogProcessed.Load() } +func OpsErrorLogSanitizedTotal() int64 { + return opsErrorLogSanitized.Load() +} + func maybeLogOpsErrorLogDrop() { now := time.Now().Unix() @@ -222,12 +226,13 @@ func maybeLogOpsErrorLogDrop() { queueCap := OpsErrorLogQueueCapacity() log.Printf( - "[OpsErrorLogger] 
queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d)", + "[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d sanitized_total=%d)", queued, queueCap, opsErrorLogEnqueued.Load(), opsErrorLogDropped.Load(), opsErrorLogProcessed.Load(), + opsErrorLogSanitized.Load(), ) } @@ -255,18 +260,49 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody if c == nil { return } + model = strings.TrimSpace(model) c.Set(opsModelKey, model) c.Set(opsStreamKey, stream) if len(requestBody) > 0 { c.Set(opsRequestBodyKey, requestBody) } + if c.Request != nil && model != "" { + ctx := context.WithValue(c.Request.Context(), ctxkey.Model, model) + c.Request = c.Request.WithContext(ctx) + } } -func setOpsSelectedAccount(c *gin.Context, accountID int64) { +func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) { + if c == nil || entry == nil { + return + } + v, ok := c.Get(opsRequestBodyKey) + if !ok { + return + } + raw, ok := v.([]byte) + if !ok || len(raw) == 0 { + return + } + entry.RequestBodyJSON, entry.RequestBodyTruncated, entry.RequestBodyBytes = service.PrepareOpsRequestBodyForQueue(raw) + opsErrorLogSanitized.Add(1) +} + +func setOpsSelectedAccount(c *gin.Context, accountID int64, platform ...string) { if c == nil || accountID <= 0 { return } c.Set(opsAccountIDKey, accountID) + if c.Request != nil { + ctx := context.WithValue(c.Request.Context(), ctxkey.AccountID, accountID) + if len(platform) > 0 { + p := strings.TrimSpace(platform[0]) + if p != "" { + ctx = context.WithValue(ctx, ctxkey.Platform, p) + } + } + c.Request = c.Request.WithContext(ctx) + } } type opsCaptureWriter struct { @@ -275,6 +311,35 @@ type opsCaptureWriter struct { buf bytes.Buffer } +const opsCaptureWriterLimit = 64 * 1024 + +var opsCaptureWriterPool = sync.Pool{ + New: func() any { + return &opsCaptureWriter{limit: 
opsCaptureWriterLimit} + }, +} + +func acquireOpsCaptureWriter(rw gin.ResponseWriter) *opsCaptureWriter { + w, ok := opsCaptureWriterPool.Get().(*opsCaptureWriter) + if !ok || w == nil { + w = &opsCaptureWriter{} + } + w.ResponseWriter = rw + w.limit = opsCaptureWriterLimit + w.buf.Reset() + return w +} + +func releaseOpsCaptureWriter(w *opsCaptureWriter) { + if w == nil { + return + } + w.ResponseWriter = nil + w.limit = opsCaptureWriterLimit + w.buf.Reset() + opsCaptureWriterPool.Put(w) +} + func (w *opsCaptureWriter) Write(b []byte) (int, error) { if w.Status() >= 400 && w.limit > 0 && w.buf.Len() < w.limit { remaining := w.limit - w.buf.Len() @@ -306,7 +371,16 @@ func (w *opsCaptureWriter) WriteString(s string) (int, error) { // - Streaming errors after the response has started (SSE) may still need explicit logging. func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { return func(c *gin.Context) { - w := &opsCaptureWriter{ResponseWriter: c.Writer, limit: 64 * 1024} + originalWriter := c.Writer + w := acquireOpsCaptureWriter(originalWriter) + defer func() { + // Restore the original writer before returning so outer middlewares + // don't observe a pooled wrapper that has been released. + if c.Writer == w { + c.Writer = originalWriter + } + releaseOpsCaptureWriter(w) + }() c.Writer = w c.Next() @@ -507,6 +581,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { RetryCount: 0, CreatedAt: time.Now(), } + applyOpsLatencyFieldsFromContext(c, entry) if apiKey != nil { entry.APIKeyID = &apiKey.ID @@ -528,14 +603,9 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { entry.ClientIP = &clientIP } - var requestBody []byte - if v, ok := c.Get(opsRequestBodyKey); ok { - if b, ok := v.([]byte); ok && len(b) > 0 { - requestBody = b - } - } // Store request headers/body only when an upstream error occurred to keep overhead minimal. 
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c) + attachOpsRequestBodyToEntry(c, entry) // Skip logging if a passthrough rule with skip_monitoring=true matched. if v, ok := c.Get(service.OpsSkipPassthroughKey); ok { @@ -544,7 +614,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { } } - enqueueOpsErrorLog(ops, entry, requestBody) + enqueueOpsErrorLog(ops, entry) return } @@ -592,8 +662,10 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { requestID = c.Writer.Header().Get("x-request-id") } - phase := classifyOpsPhase(parsed.ErrorType, parsed.Message, parsed.Code) - isBusinessLimited := classifyOpsIsBusinessLimited(parsed.ErrorType, phase, parsed.Code, status, parsed.Message) + normalizedType := normalizeOpsErrorType(parsed.ErrorType, parsed.Code) + + phase := classifyOpsPhase(normalizedType, parsed.Message, parsed.Code) + isBusinessLimited := classifyOpsIsBusinessLimited(normalizedType, phase, parsed.Code, status, parsed.Message) errorOwner := classifyOpsErrorOwner(phase, parsed.Message) errorSource := classifyOpsErrorSource(phase, parsed.Message) @@ -615,8 +687,8 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { UserAgent: c.GetHeader("User-Agent"), ErrorPhase: phase, - ErrorType: normalizeOpsErrorType(parsed.ErrorType, parsed.Code), - Severity: classifyOpsSeverity(parsed.ErrorType, status), + ErrorType: normalizedType, + Severity: classifyOpsSeverity(normalizedType, status), StatusCode: status, IsBusinessLimited: isBusinessLimited, IsCountTokens: isCountTokensRequest(c), @@ -628,10 +700,11 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { ErrorSource: errorSource, ErrorOwner: errorOwner, - IsRetryable: classifyOpsIsRetryable(parsed.ErrorType, status), + IsRetryable: classifyOpsIsRetryable(normalizedType, status), RetryCount: 0, CreatedAt: time.Now(), } + applyOpsLatencyFieldsFromContext(c, entry) // Capture upstream error context set by gateway 
services (if present). // This does NOT affect the client response; it enriches Ops troubleshooting data. @@ -707,17 +780,12 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { entry.ClientIP = &clientIP } - var requestBody []byte - if v, ok := c.Get(opsRequestBodyKey); ok { - if b, ok := v.([]byte); ok && len(b) > 0 { - requestBody = b - } - } // Persist only a minimal, whitelisted set of request headers to improve retry fidelity. // Do NOT store Authorization/Cookie/etc. entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c) + attachOpsRequestBodyToEntry(c, entry) - enqueueOpsErrorLog(ops, entry, requestBody) + enqueueOpsErrorLog(ops, entry) } } @@ -760,6 +828,44 @@ func extractOpsRetryRequestHeaders(c *gin.Context) *string { return &s } +func applyOpsLatencyFieldsFromContext(c *gin.Context, entry *service.OpsInsertErrorLogInput) { + if c == nil || entry == nil { + return + } + entry.AuthLatencyMs = getContextLatencyMs(c, service.OpsAuthLatencyMsKey) + entry.RoutingLatencyMs = getContextLatencyMs(c, service.OpsRoutingLatencyMsKey) + entry.UpstreamLatencyMs = getContextLatencyMs(c, service.OpsUpstreamLatencyMsKey) + entry.ResponseLatencyMs = getContextLatencyMs(c, service.OpsResponseLatencyMsKey) + entry.TimeToFirstTokenMs = getContextLatencyMs(c, service.OpsTimeToFirstTokenMsKey) +} + +func getContextLatencyMs(c *gin.Context, key string) *int64 { + if c == nil || strings.TrimSpace(key) == "" { + return nil + } + v, ok := c.Get(key) + if !ok { + return nil + } + var ms int64 + switch t := v.(type) { + case int: + ms = int64(t) + case int32: + ms = int64(t) + case int64: + ms = t + case float64: + ms = int64(t) + default: + return nil + } + if ms < 0 { + return nil + } + return &ms +} + type parsedOpsError struct { ErrorType string Message string @@ -835,8 +941,29 @@ func guessPlatformFromPath(path string) string { } } +// isKnownOpsErrorType returns true if t is a recognized error type used by the +// ops classification pipeline. 
Upstream proxies sometimes return garbage values +// (e.g. the Go-serialized literal "") which would pollute phase/severity +// classification if accepted blindly. +func isKnownOpsErrorType(t string) bool { + switch t { + case "invalid_request_error", + "authentication_error", + "rate_limit_error", + "billing_error", + "subscription_error", + "upstream_error", + "overloaded_error", + "api_error", + "not_found_error", + "forbidden_error": + return true + } + return false +} + func normalizeOpsErrorType(errType string, code string) string { - if errType != "" { + if errType != "" && isKnownOpsErrorType(errType) { return errType } switch strings.TrimSpace(code) { diff --git a/backend/internal/handler/ops_error_logger_test.go b/backend/internal/handler/ops_error_logger_test.go new file mode 100644 index 00000000..679dd4ce --- /dev/null +++ b/backend/internal/handler/ops_error_logger_test.go @@ -0,0 +1,276 @@ +package handler + +import ( + "net/http" + "net/http/httptest" + "sync" + "testing" + + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func resetOpsErrorLoggerStateForTest(t *testing.T) { + t.Helper() + + opsErrorLogMu.Lock() + ch := opsErrorLogQueue + opsErrorLogQueue = nil + opsErrorLogStopping = true + opsErrorLogMu.Unlock() + + if ch != nil { + close(ch) + } + opsErrorLogWorkersWg.Wait() + + opsErrorLogOnce = sync.Once{} + opsErrorLogStopOnce = sync.Once{} + opsErrorLogWorkersWg = sync.WaitGroup{} + opsErrorLogMu = sync.RWMutex{} + opsErrorLogStopping = false + + opsErrorLogQueueLen.Store(0) + opsErrorLogEnqueued.Store(0) + opsErrorLogDropped.Store(0) + opsErrorLogProcessed.Store(0) + opsErrorLogSanitized.Store(0) + opsErrorLogLastDropLogAt.Store(0) + + opsErrorLogShutdownCh = make(chan struct{}) + opsErrorLogShutdownOnce = sync.Once{} + opsErrorLogDrained.Store(false) +} + +func 
TestAttachOpsRequestBodyToEntry_SanitizeAndTrim(t *testing.T) { + resetOpsErrorLoggerStateForTest(t) + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + raw := []byte(`{"access_token":"secret-token","messages":[{"role":"user","content":"hello"}]}`) + setOpsRequestContext(c, "claude-3", false, raw) + + entry := &service.OpsInsertErrorLogInput{} + attachOpsRequestBodyToEntry(c, entry) + + require.NotNil(t, entry.RequestBodyBytes) + require.Equal(t, len(raw), *entry.RequestBodyBytes) + require.NotNil(t, entry.RequestBodyJSON) + require.NotContains(t, *entry.RequestBodyJSON, "secret-token") + require.Contains(t, *entry.RequestBodyJSON, "[REDACTED]") + require.Equal(t, int64(1), OpsErrorLogSanitizedTotal()) +} + +func TestAttachOpsRequestBodyToEntry_InvalidJSONKeepsSize(t *testing.T) { + resetOpsErrorLoggerStateForTest(t) + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + raw := []byte("not-json") + setOpsRequestContext(c, "claude-3", false, raw) + + entry := &service.OpsInsertErrorLogInput{} + attachOpsRequestBodyToEntry(c, entry) + + require.Nil(t, entry.RequestBodyJSON) + require.NotNil(t, entry.RequestBodyBytes) + require.Equal(t, len(raw), *entry.RequestBodyBytes) + require.False(t, entry.RequestBodyTruncated) + require.Equal(t, int64(1), OpsErrorLogSanitizedTotal()) +} + +func TestEnqueueOpsErrorLog_QueueFullDrop(t *testing.T) { + resetOpsErrorLoggerStateForTest(t) + + // 禁止 enqueueOpsErrorLog 触发 workers,使用测试队列验证满队列降级。 + opsErrorLogOnce.Do(func() {}) + + opsErrorLogMu.Lock() + opsErrorLogQueue = make(chan opsErrorLogJob, 1) + opsErrorLogMu.Unlock() + + ops := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + entry := &service.OpsInsertErrorLogInput{ErrorPhase: "upstream", ErrorType: 
"upstream_error"} + + enqueueOpsErrorLog(ops, entry) + enqueueOpsErrorLog(ops, entry) + + require.Equal(t, int64(1), OpsErrorLogEnqueuedTotal()) + require.Equal(t, int64(1), OpsErrorLogDroppedTotal()) + require.Equal(t, int64(1), OpsErrorLogQueueLength()) +} + +func TestAttachOpsRequestBodyToEntry_EarlyReturnBranches(t *testing.T) { + resetOpsErrorLoggerStateForTest(t) + gin.SetMode(gin.TestMode) + + entry := &service.OpsInsertErrorLogInput{} + attachOpsRequestBodyToEntry(nil, entry) + attachOpsRequestBodyToEntry(&gin.Context{}, nil) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + // 无请求体 key + attachOpsRequestBodyToEntry(c, entry) + require.Nil(t, entry.RequestBodyJSON) + require.Nil(t, entry.RequestBodyBytes) + require.False(t, entry.RequestBodyTruncated) + + // 错误类型 + c.Set(opsRequestBodyKey, "not-bytes") + attachOpsRequestBodyToEntry(c, entry) + require.Nil(t, entry.RequestBodyJSON) + require.Nil(t, entry.RequestBodyBytes) + + // 空 bytes + c.Set(opsRequestBodyKey, []byte{}) + attachOpsRequestBodyToEntry(c, entry) + require.Nil(t, entry.RequestBodyJSON) + require.Nil(t, entry.RequestBodyBytes) + + require.Equal(t, int64(0), OpsErrorLogSanitizedTotal()) +} + +func TestEnqueueOpsErrorLog_EarlyReturnBranches(t *testing.T) { + resetOpsErrorLoggerStateForTest(t) + + ops := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + entry := &service.OpsInsertErrorLogInput{ErrorPhase: "upstream", ErrorType: "upstream_error"} + + // nil 入参分支 + enqueueOpsErrorLog(nil, entry) + enqueueOpsErrorLog(ops, nil) + require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal()) + + // shutdown 分支 + close(opsErrorLogShutdownCh) + enqueueOpsErrorLog(ops, entry) + require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal()) + + // stopping 分支 + resetOpsErrorLoggerStateForTest(t) + opsErrorLogMu.Lock() + opsErrorLogStopping = true + opsErrorLogMu.Unlock() + enqueueOpsErrorLog(ops, 
entry) + require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal()) + + // queue nil 分支(防止启动 worker 干扰) + resetOpsErrorLoggerStateForTest(t) + opsErrorLogOnce.Do(func() {}) + opsErrorLogMu.Lock() + opsErrorLogQueue = nil + opsErrorLogMu.Unlock() + enqueueOpsErrorLog(ops, entry) + require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal()) +} + +func TestOpsCaptureWriterPool_ResetOnRelease(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/test", nil) + + writer := acquireOpsCaptureWriter(c.Writer) + require.NotNil(t, writer) + _, err := writer.buf.WriteString("temp-error-body") + require.NoError(t, err) + + releaseOpsCaptureWriter(writer) + + reused := acquireOpsCaptureWriter(c.Writer) + defer releaseOpsCaptureWriter(reused) + + require.Zero(t, reused.buf.Len(), "writer should be reset before reuse") +} + +func TestOpsErrorLoggerMiddleware_DoesNotBreakOuterMiddlewares(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(middleware2.Recovery()) + r.Use(middleware2.RequestLogger()) + r.Use(middleware2.Logger()) + r.GET("/v1/messages", OpsErrorLoggerMiddleware(nil), func(c *gin.Context) { + c.Status(http.StatusNoContent) + }) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/v1/messages", nil) + + require.NotPanics(t, func() { + r.ServeHTTP(rec, req) + }) + require.Equal(t, http.StatusNoContent, rec.Code) +} + +func TestIsKnownOpsErrorType(t *testing.T) { + known := []string{ + "invalid_request_error", + "authentication_error", + "rate_limit_error", + "billing_error", + "subscription_error", + "upstream_error", + "overloaded_error", + "api_error", + "not_found_error", + "forbidden_error", + } + for _, k := range known { + require.True(t, isKnownOpsErrorType(k), "expected known: %s", k) + } + + unknown := []string{"", "null", "", "random_error", "some_new_type", "\u003e"} + for _, u := range unknown { + 
require.False(t, isKnownOpsErrorType(u), "expected unknown: %q", u) + } +} + +func TestNormalizeOpsErrorType(t *testing.T) { + tests := []struct { + name string + errType string + code string + want string + }{ + // Known types pass through. + {"known invalid_request_error", "invalid_request_error", "", "invalid_request_error"}, + {"known rate_limit_error", "rate_limit_error", "", "rate_limit_error"}, + {"known upstream_error", "upstream_error", "", "upstream_error"}, + + // Unknown/garbage types are rejected and fall through to code-based or default. + {"nil literal from upstream", "", "", "api_error"}, + {"null string", "null", "", "api_error"}, + {"random string", "something_weird", "", "api_error"}, + + // Unknown type but known code still maps correctly. + {"nil with INSUFFICIENT_BALANCE code", "", "INSUFFICIENT_BALANCE", "billing_error"}, + {"nil with USAGE_LIMIT_EXCEEDED code", "", "USAGE_LIMIT_EXCEEDED", "subscription_error"}, + + // Empty type falls through to code-based mapping. + {"empty type with balance code", "", "INSUFFICIENT_BALANCE", "billing_error"}, + {"empty type with subscription code", "", "SUBSCRIPTION_NOT_FOUND", "subscription_error"}, + {"empty type no code", "", "", "api_error"}, + + // Known type overrides conflicting code-based mapping. 
+ {"known type overrides conflicting code", "rate_limit_error", "INSUFFICIENT_BALANCE", "rate_limit_error"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := normalizeOpsErrorType(tt.errType, tt.code) + require.Equal(t, tt.want, got) + }) + } +} diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 2029f116..1188d55e 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -32,25 +32,28 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { } response.Success(c, dto.PublicSettings{ - RegistrationEnabled: settings.RegistrationEnabled, - EmailVerifyEnabled: settings.EmailVerifyEnabled, - PromoCodeEnabled: settings.PromoCodeEnabled, - PasswordResetEnabled: settings.PasswordResetEnabled, - InvitationCodeEnabled: settings.InvitationCodeEnabled, - TotpEnabled: settings.TotpEnabled, - TurnstileEnabled: settings.TurnstileEnabled, - TurnstileSiteKey: settings.TurnstileSiteKey, - SiteName: settings.SiteName, - SiteLogo: settings.SiteLogo, - SiteSubtitle: settings.SiteSubtitle, - APIBaseURL: settings.APIBaseURL, - ContactInfo: settings.ContactInfo, - DocURL: settings.DocURL, - HomeContent: settings.HomeContent, - HideCcsImportButton: settings.HideCcsImportButton, - PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, - PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, - LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, - Version: h.version, + RegistrationEnabled: settings.RegistrationEnabled, + EmailVerifyEnabled: settings.EmailVerifyEnabled, + RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist, + PromoCodeEnabled: settings.PromoCodeEnabled, + PasswordResetEnabled: settings.PasswordResetEnabled, + InvitationCodeEnabled: settings.InvitationCodeEnabled, + TotpEnabled: settings.TotpEnabled, + TurnstileEnabled: settings.TurnstileEnabled, + TurnstileSiteKey: settings.TurnstileSiteKey, + 
SiteName: settings.SiteName, + SiteLogo: settings.SiteLogo, + SiteSubtitle: settings.SiteSubtitle, + APIBaseURL: settings.APIBaseURL, + ContactInfo: settings.ContactInfo, + DocURL: settings.DocURL, + HomeContent: settings.HomeContent, + HideCcsImportButton: settings.HideCcsImportButton, + PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, + PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, + CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), + LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, + SoraClientEnabled: settings.SoraClientEnabled, + Version: h.version, }) } diff --git a/backend/internal/handler/sora_client_handler.go b/backend/internal/handler/sora_client_handler.go new file mode 100644 index 00000000..80acc833 --- /dev/null +++ b/backend/internal/handler/sora_client_handler.go @@ -0,0 +1,979 @@ +package handler + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +const ( + // 上游模型缓存 TTL + modelCacheTTL = 1 * time.Hour // 上游获取成功 + modelCacheFailedTTL = 2 * time.Minute // 上游获取失败(降级到本地) +) + +// SoraClientHandler 处理 Sora 客户端 API 请求。 +type SoraClientHandler struct { + genService *service.SoraGenerationService + quotaService *service.SoraQuotaService + s3Storage *service.SoraS3Storage + soraGatewayService *service.SoraGatewayService + gatewayService *service.GatewayService + mediaStorage *service.SoraMediaStorage + apiKeyService *service.APIKeyService + + // 上游模型缓存 + modelCacheMu sync.RWMutex + cachedFamilies []service.SoraModelFamily + modelCacheTime time.Time + modelCacheUpstream bool // 是否来自上游(决定 TTL) +} + 
+// NewSoraClientHandler 创建 Sora 客户端 Handler。 +func NewSoraClientHandler( + genService *service.SoraGenerationService, + quotaService *service.SoraQuotaService, + s3Storage *service.SoraS3Storage, + soraGatewayService *service.SoraGatewayService, + gatewayService *service.GatewayService, + mediaStorage *service.SoraMediaStorage, + apiKeyService *service.APIKeyService, +) *SoraClientHandler { + return &SoraClientHandler{ + genService: genService, + quotaService: quotaService, + s3Storage: s3Storage, + soraGatewayService: soraGatewayService, + gatewayService: gatewayService, + mediaStorage: mediaStorage, + apiKeyService: apiKeyService, + } +} + +// GenerateRequest 生成请求。 +type GenerateRequest struct { + Model string `json:"model" binding:"required"` + Prompt string `json:"prompt" binding:"required"` + MediaType string `json:"media_type"` // video / image,默认 video + VideoCount int `json:"video_count,omitempty"` // 视频数量(1-3) + ImageInput string `json:"image_input,omitempty"` // 参考图(base64 或 URL) + APIKeyID *int64 `json:"api_key_id,omitempty"` // 前端传递的 API Key ID +} + +// Generate 异步生成 — 创建 pending 记录后立即返回。 +// POST /api/v1/sora/generate +func (h *SoraClientHandler) Generate(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + var req GenerateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.Error(c, http.StatusBadRequest, "参数错误: "+err.Error()) + return + } + + if req.MediaType == "" { + req.MediaType = "video" + } + req.VideoCount = normalizeVideoCount(req.MediaType, req.VideoCount) + + // 并发数检查(最多 3 个) + activeCount, err := h.genService.CountActiveByUser(c.Request.Context(), userID) + if err != nil { + response.ErrorFrom(c, err) + return + } + if activeCount >= 3 { + response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个") + return + } + + // 配额检查(粗略检查,实际文件大小在上传后才知道) + if h.quotaService != nil { + if err := h.quotaService.CheckQuota(c.Request.Context(), 
userID, 0); err != nil { + var quotaErr *service.QuotaExceededError + if errors.As(err, &quotaErr) { + response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间") + return + } + response.Error(c, http.StatusForbidden, err.Error()) + return + } + } + + // 获取 API Key ID 和 Group ID + var apiKeyID *int64 + var groupID *int64 + + if req.APIKeyID != nil && h.apiKeyService != nil { + // 前端传递了 api_key_id,需要校验 + apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), *req.APIKeyID) + if err != nil { + response.Error(c, http.StatusBadRequest, "API Key 不存在") + return + } + if apiKey.UserID != userID { + response.Error(c, http.StatusForbidden, "API Key 不属于当前用户") + return + } + if apiKey.Status != service.StatusAPIKeyActive { + response.Error(c, http.StatusForbidden, "API Key 不可用") + return + } + apiKeyID = &apiKey.ID + groupID = apiKey.GroupID + } else if id, ok := c.Get("api_key_id"); ok { + // 兼容 API Key 认证路径(/sora/v1/ 网关路由) + if v, ok := id.(int64); ok { + apiKeyID = &v + } + } + + gen, err := h.genService.CreatePending(c.Request.Context(), userID, apiKeyID, req.Model, req.Prompt, req.MediaType) + if err != nil { + if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) { + response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个") + return + } + response.ErrorFrom(c, err) + return + } + + // 启动后台异步生成 goroutine + go h.processGeneration(gen.ID, userID, groupID, req.Model, req.Prompt, req.MediaType, req.ImageInput, req.VideoCount) + + response.Success(c, gin.H{ + "generation_id": gen.ID, + "status": gen.Status, + }) +} + +// processGeneration 后台异步执行 Sora 生成任务。 +// 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储(S3 → 本地 → 上游)→ 更新记录。 +func (h *SoraClientHandler) processGeneration(genID int64, userID int64, groupID *int64, model, prompt, mediaType, imageInput string, videoCount int) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + // 标记为生成中 + if err := h.genService.MarkGenerating(ctx, genID, ""); err != nil { + if 
errors.Is(err, service.ErrSoraGenerationStateConflict) { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务状态已变化,跳过生成 id=%d", genID) + return + } + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记生成中失败 id=%d err=%v", genID, err) + return + } + + logger.LegacyPrintf( + "handler.sora_client", + "[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d", + genID, + userID, + groupIDForLog(groupID), + model, + mediaType, + videoCount, + strings.TrimSpace(imageInput) != "", + len(strings.TrimSpace(prompt)), + ) + + // 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底 + if groupID == nil { + ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora) + } + + if h.gatewayService == nil { + _ = h.genService.MarkFailed(ctx, genID, "内部错误: gatewayService 未初始化") + return + } + + // 选择 Sora 账号 + account, err := h.gatewayService.SelectAccountForModel(ctx, groupID, "", model) + if err != nil { + logger.LegacyPrintf( + "handler.sora_client", + "[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v", + genID, + userID, + groupIDForLog(groupID), + model, + err, + ) + _ = h.genService.MarkFailed(ctx, genID, "选择账号失败: "+err.Error()) + return + } + logger.LegacyPrintf( + "handler.sora_client", + "[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s", + genID, + userID, + groupIDForLog(groupID), + model, + account.ID, + account.Name, + account.Platform, + account.Type, + ) + + // 构建 chat completions 请求体(非流式) + body := buildAsyncRequestBody(model, prompt, imageInput, normalizeVideoCount(mediaType, videoCount)) + + if h.soraGatewayService == nil { + _ = h.genService.MarkFailed(ctx, genID, "内部错误: soraGatewayService 未初始化") + return + } + + // 创建 mock gin 上下文用于 Forward(捕获响应以提取媒体 URL) + recorder := httptest.NewRecorder() + mockGinCtx, _ := gin.CreateTestContext(recorder) + mockGinCtx.Request, _ = http.NewRequest("POST", "/", nil) + + // 调用 Forward(非流式) + 
result, err := h.soraGatewayService.Forward(ctx, mockGinCtx, account, body, false) + if err != nil { + logger.LegacyPrintf( + "handler.sora_client", + "[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v", + genID, + account.ID, + model, + recorder.Code, + trimForLog(recorder.Body.String(), 400), + err, + ) + // 检查是否已取消 + gen, _ := h.genService.GetByID(ctx, genID, userID) + if gen != nil && gen.Status == service.SoraGenStatusCancelled { + return + } + _ = h.genService.MarkFailed(ctx, genID, "生成失败: "+err.Error()) + return + } + + // 提取媒体 URL(优先从 ForwardResult,其次从响应体解析) + mediaURL, mediaURLs := extractMediaURLsFromResult(result, recorder) + if mediaURL == "" { + logger.LegacyPrintf( + "handler.sora_client", + "[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s", + genID, + account.ID, + model, + recorder.Code, + trimForLog(recorder.Body.String(), 400), + ) + _ = h.genService.MarkFailed(ctx, genID, "未获取到媒体 URL") + return + } + + // 检查任务是否已被取消 + gen, _ := h.genService.GetByID(ctx, genID, userID) + if gen != nil && gen.Status == service.SoraGenStatusCancelled { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务已取消,跳过存储 id=%d", genID) + return + } + + // 三层降级存储:S3 → 本地 → 上游临时 URL + storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(ctx, userID, mediaType, mediaURL, mediaURLs) + + usageAdded := false + if (storageType == service.SoraStorageTypeS3 || storageType == service.SoraStorageTypeLocal) && fileSize > 0 && h.quotaService != nil { + if err := h.quotaService.AddUsage(ctx, userID, fileSize); err != nil { + h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) + var quotaErr *service.QuotaExceededError + if errors.As(err, &quotaErr) { + _ = h.genService.MarkFailed(ctx, genID, "存储配额已满,请删除不需要的作品释放空间") + return + } + _ = h.genService.MarkFailed(ctx, genID, "存储配额更新失败: "+err.Error()) + return + } + usageAdded = true + } + + // 存储完成后再做一次取消检查,防止取消被 completed 覆盖。 + gen, _ = 
h.genService.GetByID(ctx, genID, userID) + if gen != nil && gen.Status == service.SoraGenStatusCancelled { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d", genID) + h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) + if usageAdded && h.quotaService != nil { + _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize) + } + return + } + + // 标记完成 + if err := h.genService.MarkCompleted(ctx, genID, storedURL, storedURLs, storageType, s3Keys, fileSize); err != nil { + if errors.Is(err, service.ErrSoraGenerationStateConflict) { + h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) + if usageAdded && h.quotaService != nil { + _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize) + } + return + } + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记完成失败 id=%d err=%v", genID, err) + return + } + + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成完成 id=%d storage=%s size=%d", genID, storageType, fileSize) +} + +// storeMediaWithDegradation 实现三层降级存储链:S3 → 本地 → 上游。 +func (h *SoraClientHandler) storeMediaWithDegradation( + ctx context.Context, userID int64, mediaType string, + mediaURL string, mediaURLs []string, +) (storedURL string, storedURLs []string, storageType string, s3Keys []string, fileSize int64) { + urls := mediaURLs + if len(urls) == 0 { + urls = []string{mediaURL} + } + + // 第一层:尝试 S3 + if h.s3Storage != nil && h.s3Storage.Enabled(ctx) { + keys := make([]string, 0, len(urls)) + var totalSize int64 + allOK := true + for _, u := range urls { + key, size, err := h.s3Storage.UploadFromURL(ctx, userID, u) + if err != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] S3 上传失败 err=%v", err) + allOK = false + // 清理已上传的文件 + if len(keys) > 0 { + _ = h.s3Storage.DeleteObjects(ctx, keys) + } + break + } + keys = append(keys, key) + totalSize += size + } + if allOK && len(keys) > 0 { + accessURLs := make([]string, 0, len(keys)) + for _, key := range keys { + accessURL, err := 
h.s3Storage.GetAccessURL(ctx, key) + if err != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成 S3 访问 URL 失败 err=%v", err) + _ = h.s3Storage.DeleteObjects(ctx, keys) + allOK = false + break + } + accessURLs = append(accessURLs, accessURL) + } + if allOK && len(accessURLs) > 0 { + return accessURLs[0], accessURLs, service.SoraStorageTypeS3, keys, totalSize + } + } + } + + // 第二层:尝试本地存储 + if h.mediaStorage != nil && h.mediaStorage.Enabled() { + storedPaths, err := h.mediaStorage.StoreFromURLs(ctx, mediaType, urls) + if err == nil && len(storedPaths) > 0 { + firstPath := storedPaths[0] + totalSize, sizeErr := h.mediaStorage.TotalSizeByRelativePaths(storedPaths) + if sizeErr != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 统计本地文件大小失败 err=%v", sizeErr) + } + return firstPath, storedPaths, service.SoraStorageTypeLocal, nil, totalSize + } + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 本地存储失败 err=%v", err) + } + + // 第三层:保留上游临时 URL + return urls[0], urls, service.SoraStorageTypeUpstream, nil, 0 +} + +// buildAsyncRequestBody 构建 Sora 异步生成的 chat completions 请求体。 +func buildAsyncRequestBody(model, prompt, imageInput string, videoCount int) []byte { + body := map[string]any{ + "model": model, + "messages": []map[string]string{ + {"role": "user", "content": prompt}, + }, + "stream": false, + } + if imageInput != "" { + body["image_input"] = imageInput + } + if videoCount > 1 { + body["video_count"] = videoCount + } + b, _ := json.Marshal(body) + return b +} + +func normalizeVideoCount(mediaType string, videoCount int) int { + if mediaType != "video" { + return 1 + } + if videoCount <= 0 { + return 1 + } + if videoCount > 3 { + return 3 + } + return videoCount +} + +// extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。 +// OAuth 路径:ForwardResult.MediaURL 已填充。 +// APIKey 路径:需从响应体解析 media_url / media_urls 字段。 +func extractMediaURLsFromResult(result *service.ForwardResult, recorder *httptest.ResponseRecorder) (string, []string) { 
+ // 优先从 ForwardResult 获取(OAuth 路径) + if result != nil && result.MediaURL != "" { + // 尝试从响应体获取完整 URL 列表 + if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 { + return urls[0], urls + } + return result.MediaURL, []string{result.MediaURL} + } + + // 从响应体解析(APIKey 路径) + if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 { + return urls[0], urls + } + + return "", nil +} + +// parseMediaURLsFromBody 从 JSON 响应体中解析 media_url / media_urls 字段。 +func parseMediaURLsFromBody(body []byte) []string { + if len(body) == 0 { + return nil + } + var resp map[string]any + if err := json.Unmarshal(body, &resp); err != nil { + return nil + } + + // 优先 media_urls(多图数组) + if rawURLs, ok := resp["media_urls"]; ok { + if arr, ok := rawURLs.([]any); ok && len(arr) > 0 { + urls := make([]string, 0, len(arr)) + for _, item := range arr { + if s, ok := item.(string); ok && s != "" { + urls = append(urls, s) + } + } + if len(urls) > 0 { + return urls + } + } + } + + // 回退到 media_url(单个 URL) + if url, ok := resp["media_url"].(string); ok && url != "" { + return []string{url} + } + + return nil +} + +// ListGenerations 查询生成记录列表。 +// GET /api/v1/sora/generations +func (h *SoraClientHandler) ListGenerations(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + + params := service.SoraGenerationListParams{ + UserID: userID, + Status: c.Query("status"), + StorageType: c.Query("storage_type"), + MediaType: c.Query("media_type"), + Page: page, + PageSize: pageSize, + } + + gens, total, err := h.genService.List(c.Request.Context(), params) + if err != nil { + response.ErrorFrom(c, err) + return + } + + // 为 S3 记录动态生成预签名 URL + for _, gen := range gens { + _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen) + } + + response.Success(c, gin.H{ + 
"data": gens, + "total": total, + "page": page, + }) +} + +// GetGeneration 查询生成记录详情。 +// GET /api/v1/sora/generations/:id +func (h *SoraClientHandler) GetGeneration(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.Error(c, http.StatusBadRequest, "无效的 ID") + return + } + + gen, err := h.genService.GetByID(c.Request.Context(), id, userID) + if err != nil { + response.Error(c, http.StatusNotFound, err.Error()) + return + } + + _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen) + response.Success(c, gen) +} + +// DeleteGeneration 删除生成记录。 +// DELETE /api/v1/sora/generations/:id +func (h *SoraClientHandler) DeleteGeneration(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.Error(c, http.StatusBadRequest, "无效的 ID") + return + } + + gen, err := h.genService.GetByID(c.Request.Context(), id, userID) + if err != nil { + response.Error(c, http.StatusNotFound, err.Error()) + return + } + + // 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。 + if gen.StorageType == service.SoraStorageTypeLocal && h.mediaStorage != nil { + paths := gen.MediaURLs + if len(paths) == 0 && gen.MediaURL != "" { + paths = []string{gen.MediaURL} + } + if err := h.mediaStorage.DeleteByRelativePaths(paths); err != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 删除本地文件失败 id=%d err=%v", id, err) + } + } + + if err := h.genService.Delete(c.Request.Context(), id, userID); err != nil { + response.Error(c, http.StatusNotFound, err.Error()) + return + } + + response.Success(c, gin.H{"message": "已删除"}) +} + +// GetQuota 查询用户存储配额。 +// GET /api/v1/sora/quota +func (h *SoraClientHandler) GetQuota(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + 
response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + if h.quotaService == nil { + response.Success(c, service.QuotaInfo{QuotaSource: "unlimited", Source: "unlimited"}) + return + } + + quota, err := h.quotaService.GetQuota(c.Request.Context(), userID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, quota) +} + +// CancelGeneration 取消生成任务。 +// POST /api/v1/sora/generations/:id/cancel +func (h *SoraClientHandler) CancelGeneration(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.Error(c, http.StatusBadRequest, "无效的 ID") + return + } + + // 权限校验 + gen, err := h.genService.GetByID(c.Request.Context(), id, userID) + if err != nil { + response.Error(c, http.StatusNotFound, err.Error()) + return + } + _ = gen + + if err := h.genService.MarkCancelled(c.Request.Context(), id); err != nil { + if errors.Is(err, service.ErrSoraGenerationNotActive) { + response.Error(c, http.StatusConflict, "任务已结束,无法取消") + return + } + response.Error(c, http.StatusBadRequest, err.Error()) + return + } + + response.Success(c, gin.H{"message": "已取消"}) +} + +// SaveToStorage 手动保存 upstream 记录到 S3。 +// POST /api/v1/sora/generations/:id/save +func (h *SoraClientHandler) SaveToStorage(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.Error(c, http.StatusBadRequest, "无效的 ID") + return + } + + gen, err := h.genService.GetByID(c.Request.Context(), id, userID) + if err != nil { + response.Error(c, http.StatusNotFound, err.Error()) + return + } + + if gen.StorageType != service.SoraStorageTypeUpstream { + response.Error(c, http.StatusBadRequest, "仅 upstream 类型的记录可手动保存") + return + } + if gen.MediaURL == "" { + 
response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期") + return + } + + if h.s3Storage == nil || !h.s3Storage.Enabled(c.Request.Context()) { + response.Error(c, http.StatusServiceUnavailable, "云存储未配置,请联系管理员") + return + } + + sourceURLs := gen.MediaURLs + if len(sourceURLs) == 0 && gen.MediaURL != "" { + sourceURLs = []string{gen.MediaURL} + } + if len(sourceURLs) == 0 { + response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期") + return + } + + uploadedKeys := make([]string, 0, len(sourceURLs)) + accessURLs := make([]string, 0, len(sourceURLs)) + var totalSize int64 + + for _, sourceURL := range sourceURLs { + objectKey, fileSize, uploadErr := h.s3Storage.UploadFromURL(c.Request.Context(), userID, sourceURL) + if uploadErr != nil { + if len(uploadedKeys) > 0 { + _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) + } + var upstreamErr *service.UpstreamDownloadError + if errors.As(uploadErr, &upstreamErr) && (upstreamErr.StatusCode == http.StatusForbidden || upstreamErr.StatusCode == http.StatusNotFound) { + response.Error(c, http.StatusGone, "媒体链接已过期,无法保存") + return + } + response.Error(c, http.StatusInternalServerError, "上传到 S3 失败: "+uploadErr.Error()) + return + } + accessURL, err := h.s3Storage.GetAccessURL(c.Request.Context(), objectKey) + if err != nil { + uploadedKeys = append(uploadedKeys, objectKey) + _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) + response.Error(c, http.StatusInternalServerError, "生成 S3 访问链接失败: "+err.Error()) + return + } + uploadedKeys = append(uploadedKeys, objectKey) + accessURLs = append(accessURLs, accessURL) + totalSize += fileSize + } + + usageAdded := false + if totalSize > 0 && h.quotaService != nil { + if err := h.quotaService.AddUsage(c.Request.Context(), userID, totalSize); err != nil { + _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) + var quotaErr *service.QuotaExceededError + if errors.As(err, &quotaErr) { + response.Error(c, http.StatusTooManyRequests, 
"存储配额已满,请删除不需要的作品释放空间") + return + } + response.Error(c, http.StatusInternalServerError, "配额更新失败: "+err.Error()) + return + } + usageAdded = true + } + + if err := h.genService.UpdateStorageForCompleted( + c.Request.Context(), + id, + accessURLs[0], + accessURLs, + service.SoraStorageTypeS3, + uploadedKeys, + totalSize, + ); err != nil { + _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) + if usageAdded && h.quotaService != nil { + _ = h.quotaService.ReleaseUsage(c.Request.Context(), userID, totalSize) + } + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{ + "message": "已保存到 S3", + "object_key": uploadedKeys[0], + "object_keys": uploadedKeys, + }) +} + +// GetStorageStatus 返回存储状态。 +// GET /api/v1/sora/storage-status +func (h *SoraClientHandler) GetStorageStatus(c *gin.Context) { + s3Enabled := h.s3Storage != nil && h.s3Storage.Enabled(c.Request.Context()) + s3Healthy := false + if s3Enabled { + s3Healthy = h.s3Storage.IsHealthy(c.Request.Context()) + } + localEnabled := h.mediaStorage != nil && h.mediaStorage.Enabled() + response.Success(c, gin.H{ + "s3_enabled": s3Enabled, + "s3_healthy": s3Healthy, + "local_enabled": localEnabled, + }) +} + +func (h *SoraClientHandler) cleanupStoredMedia(ctx context.Context, storageType string, s3Keys []string, localPaths []string) { + switch storageType { + case service.SoraStorageTypeS3: + if h.s3Storage != nil && len(s3Keys) > 0 { + if err := h.s3Storage.DeleteObjects(ctx, s3Keys); err != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理 S3 文件失败 keys=%v err=%v", s3Keys, err) + } + } + case service.SoraStorageTypeLocal: + if h.mediaStorage != nil && len(localPaths) > 0 { + if err := h.mediaStorage.DeleteByRelativePaths(localPaths); err != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理本地文件失败 paths=%v err=%v", localPaths, err) + } + } + } +} + +// getUserIDFromContext 从 gin 上下文中提取用户 ID。 +func getUserIDFromContext(c *gin.Context) int64 { + if subject, ok 
:= middleware2.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 { + return subject.UserID + } + + if id, ok := c.Get("user_id"); ok { + switch v := id.(type) { + case int64: + return v + case float64: + return int64(v) + case string: + n, _ := strconv.ParseInt(v, 10, 64) + return n + } + } + // 尝试从 JWT claims 获取 + if id, ok := c.Get("userID"); ok { + if v, ok := id.(int64); ok { + return v + } + } + return 0 +} + +func groupIDForLog(groupID *int64) int64 { + if groupID == nil { + return 0 + } + return *groupID +} + +func trimForLog(raw string, maxLen int) string { + trimmed := strings.TrimSpace(raw) + if maxLen <= 0 || len(trimmed) <= maxLen { + return trimmed + } + return trimmed[:maxLen] + "...(truncated)" +} + +// GetModels 获取可用 Sora 模型家族列表。 +// 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。 +// GET /api/v1/sora/models +func (h *SoraClientHandler) GetModels(c *gin.Context) { + families := h.getModelFamilies(c.Request.Context()) + response.Success(c, families) +} + +// getModelFamilies 获取模型家族列表(带缓存)。 +func (h *SoraClientHandler) getModelFamilies(ctx context.Context) []service.SoraModelFamily { + // 读锁检查缓存 + h.modelCacheMu.RLock() + ttl := modelCacheTTL + if !h.modelCacheUpstream { + ttl = modelCacheFailedTTL + } + if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl { + families := h.cachedFamilies + h.modelCacheMu.RUnlock() + return families + } + h.modelCacheMu.RUnlock() + + // 写锁更新缓存 + h.modelCacheMu.Lock() + defer h.modelCacheMu.Unlock() + + // double-check + ttl = modelCacheTTL + if !h.modelCacheUpstream { + ttl = modelCacheFailedTTL + } + if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl { + return h.cachedFamilies + } + + // 尝试从上游获取 + families, err := h.fetchUpstreamModels(ctx) + if err != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 上游模型获取失败,使用本地配置: %v", err) + families = service.BuildSoraModelFamilies() + h.cachedFamilies = families + h.modelCacheTime = time.Now() + h.modelCacheUpstream = false + return families + 
} + + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 从上游同步到 %d 个模型家族", len(families)) + h.cachedFamilies = families + h.modelCacheTime = time.Now() + h.modelCacheUpstream = true + return families +} + +// fetchUpstreamModels 从上游 Sora API 获取模型列表。 +func (h *SoraClientHandler) fetchUpstreamModels(ctx context.Context) ([]service.SoraModelFamily, error) { + if h.gatewayService == nil { + return nil, fmt.Errorf("gatewayService 未初始化") + } + + // 设置 ForcePlatform 用于 Sora 账号选择 + ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora) + + // 选择一个 Sora 账号 + account, err := h.gatewayService.SelectAccountForModel(ctx, nil, "", "sora2-landscape-10s") + if err != nil { + return nil, fmt.Errorf("选择 Sora 账号失败: %w", err) + } + + // 仅支持 API Key 类型账号 + if account.Type != service.AccountTypeAPIKey { + return nil, fmt.Errorf("当前账号类型 %s 不支持模型同步", account.Type) + } + + apiKey := account.GetCredential("api_key") + if apiKey == "" { + return nil, fmt.Errorf("账号缺少 api_key") + } + + baseURL := account.GetBaseURL() + if baseURL == "" { + return nil, fmt.Errorf("账号缺少 base_url") + } + + // 构建上游模型列表请求 + modelsURL := strings.TrimRight(baseURL, "/") + "/sora/v1/models" + + reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelsURL, nil) + if err != nil { + return nil, fmt.Errorf("创建请求失败: %w", err) + } + req.Header.Set("Authorization", "Bearer "+apiKey) + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("请求上游失败: %w", err) + } + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("上游返回状态码 %d", resp.StatusCode) + } + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024)) + if err != nil { + return nil, fmt.Errorf("读取响应失败: %w", err) + } + + // 解析 OpenAI 格式的模型列表 + var modelsResp struct { + Data []struct { + ID string `json:"id"` + } 
`json:"data"` + } + if err := json.Unmarshal(body, &modelsResp); err != nil { + return nil, fmt.Errorf("解析响应失败: %w", err) + } + + if len(modelsResp.Data) == 0 { + return nil, fmt.Errorf("上游返回空模型列表") + } + + // 提取模型 ID + modelIDs := make([]string, 0, len(modelsResp.Data)) + for _, m := range modelsResp.Data { + modelIDs = append(modelIDs, m.ID) + } + + // 转换为模型家族 + families := service.BuildSoraModelFamiliesFromIDs(modelIDs) + if len(families) == 0 { + return nil, fmt.Errorf("未能从上游模型列表中识别出有效的模型家族") + } + + return families, nil +} diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go new file mode 100644 index 00000000..d2d9790d --- /dev/null +++ b/backend/internal/handler/sora_client_handler_test.go @@ -0,0 +1,3153 @@ +//go:build unit + +package handler + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +// ==================== Stub: SoraGenerationRepository ==================== + +var _ service.SoraGenerationRepository = (*stubSoraGenRepo)(nil) + +type stubSoraGenRepo struct { + gens map[int64]*service.SoraGeneration + nextID int64 + createErr error + getErr error + updateErr error + deleteErr error + listErr error + countErr error + countValue int64 + + // 条件性 Update 失败:前 updateFailAfterN 次成功,之后失败 + updateCallCount *int32 + updateFailAfterN int32 + + // 条件性 GetByID 状态覆盖:前 getByIDOverrideAfterN 次正常返回,之后返回 overrideStatus + getByIDCallCount int32 + getByIDOverrideAfterN int32 // 0 = 不覆盖 + getByIDOverrideStatus string +} + +func newStubSoraGenRepo() *stubSoraGenRepo 
{ + return &stubSoraGenRepo{gens: make(map[int64]*service.SoraGeneration), nextID: 1} +} + +func (r *stubSoraGenRepo) Create(_ context.Context, gen *service.SoraGeneration) error { + if r.createErr != nil { + return r.createErr + } + gen.ID = r.nextID + r.nextID++ + r.gens[gen.ID] = gen + return nil +} +func (r *stubSoraGenRepo) GetByID(_ context.Context, id int64) (*service.SoraGeneration, error) { + if r.getErr != nil { + return nil, r.getErr + } + gen, ok := r.gens[id] + if !ok { + return nil, fmt.Errorf("not found") + } + // 条件性状态覆盖:模拟外部取消等场景 + if r.getByIDOverrideAfterN > 0 { + n := atomic.AddInt32(&r.getByIDCallCount, 1) + if n > r.getByIDOverrideAfterN { + cp := *gen + cp.Status = r.getByIDOverrideStatus + return &cp, nil + } + } + return gen, nil +} +func (r *stubSoraGenRepo) Update(_ context.Context, gen *service.SoraGeneration) error { + // 条件性失败:前 N 次成功,之后失败 + if r.updateCallCount != nil { + n := atomic.AddInt32(r.updateCallCount, 1) + if n > r.updateFailAfterN { + return fmt.Errorf("conditional update error (call #%d)", n) + } + } + if r.updateErr != nil { + return r.updateErr + } + r.gens[gen.ID] = gen + return nil +} +func (r *stubSoraGenRepo) Delete(_ context.Context, id int64) error { + if r.deleteErr != nil { + return r.deleteErr + } + delete(r.gens, id) + return nil +} +func (r *stubSoraGenRepo) List(_ context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) { + if r.listErr != nil { + return nil, 0, r.listErr + } + var result []*service.SoraGeneration + for _, gen := range r.gens { + if gen.UserID != params.UserID { + continue + } + result = append(result, gen) + } + return result, int64(len(result)), nil +} +func (r *stubSoraGenRepo) CountByUserAndStatus(_ context.Context, _ int64, _ []string) (int64, error) { + if r.countErr != nil { + return 0, r.countErr + } + return r.countValue, nil +} + +// ==================== 辅助函数 ==================== + +func newTestSoraClientHandler(repo *stubSoraGenRepo) 
*SoraClientHandler { + genService := service.NewSoraGenerationService(repo, nil, nil) + return &SoraClientHandler{genService: genService} +} + +func makeGinContext(method, path, body string, userID int64) (*gin.Context, *httptest.ResponseRecorder) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + if body != "" { + c.Request = httptest.NewRequest(method, path, strings.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + } else { + c.Request = httptest.NewRequest(method, path, nil) + } + if userID > 0 { + c.Set("user_id", userID) + } + return c, rec +} + +func parseResponse(t *testing.T, rec *httptest.ResponseRecorder) map[string]any { + t.Helper() + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + return resp +} + +// ==================== 纯函数测试: buildAsyncRequestBody ==================== + +func TestBuildAsyncRequestBody(t *testing.T) { + body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 1) + var parsed map[string]any + require.NoError(t, json.Unmarshal(body, &parsed)) + require.Equal(t, "sora2-landscape-10s", parsed["model"]) + require.Equal(t, false, parsed["stream"]) + + msgs := parsed["messages"].([]any) + require.Len(t, msgs, 1) + msg := msgs[0].(map[string]any) + require.Equal(t, "user", msg["role"]) + require.Equal(t, "一只猫在跳舞", msg["content"]) +} + +func TestBuildAsyncRequestBody_EmptyPrompt(t *testing.T) { + body := buildAsyncRequestBody("gpt-image", "", "", 1) + var parsed map[string]any + require.NoError(t, json.Unmarshal(body, &parsed)) + require.Equal(t, "gpt-image", parsed["model"]) + msgs := parsed["messages"].([]any) + msg := msgs[0].(map[string]any) + require.Equal(t, "", msg["content"]) +} + +func TestBuildAsyncRequestBody_WithImageInput(t *testing.T) { + body := buildAsyncRequestBody("gpt-image", "一只猫", "https://example.com/ref.png", 1) + var parsed map[string]any + require.NoError(t, json.Unmarshal(body, &parsed)) + require.Equal(t, 
"https://example.com/ref.png", parsed["image_input"]) +} + +func TestBuildAsyncRequestBody_WithVideoCount(t *testing.T) { + body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 3) + var parsed map[string]any + require.NoError(t, json.Unmarshal(body, &parsed)) + require.Equal(t, float64(3), parsed["video_count"]) +} + +func TestNormalizeVideoCount(t *testing.T) { + require.Equal(t, 1, normalizeVideoCount("video", 0)) + require.Equal(t, 2, normalizeVideoCount("video", 2)) + require.Equal(t, 3, normalizeVideoCount("video", 5)) + require.Equal(t, 1, normalizeVideoCount("image", 3)) +} + +// ==================== 纯函数测试: parseMediaURLsFromBody ==================== + +func TestParseMediaURLsFromBody_MediaURLs(t *testing.T) { + urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","https://a.com/2.mp4"]}`)) + require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls) +} + +func TestParseMediaURLsFromBody_SingleMediaURL(t *testing.T) { + urls := parseMediaURLsFromBody([]byte(`{"media_url":"https://a.com/video.mp4"}`)) + require.Equal(t, []string{"https://a.com/video.mp4"}, urls) +} + +func TestParseMediaURLsFromBody_EmptyBody(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody(nil)) + require.Nil(t, parseMediaURLsFromBody([]byte{})) +} + +func TestParseMediaURLsFromBody_InvalidJSON(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody([]byte("not json"))) +} + +func TestParseMediaURLsFromBody_NoMediaFields(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody([]byte(`{"data":"something"}`))) +} + +func TestParseMediaURLsFromBody_EmptyMediaURL(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":""}`))) +} + +func TestParseMediaURLsFromBody_EmptyMediaURLs(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":[]}`))) +} + +func TestParseMediaURLsFromBody_MediaURLsPriority(t *testing.T) { + body := 
`{"media_url":"https://single.com/1.mp4","media_urls":["https://multi.com/a.mp4","https://multi.com/b.mp4"]}` + urls := parseMediaURLsFromBody([]byte(body)) + require.Len(t, urls, 2) + require.Equal(t, "https://multi.com/a.mp4", urls[0]) +} + +func TestParseMediaURLsFromBody_FilterEmpty(t *testing.T) { + urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","","https://a.com/2.mp4"]}`)) + require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls) +} + +func TestParseMediaURLsFromBody_AllEmpty(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":["",""]}`))) +} + +func TestParseMediaURLsFromBody_NonStringArray(t *testing.T) { + // media_urls 不是 string 数组 + require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":"not-array"}`))) +} + +func TestParseMediaURLsFromBody_MediaURLNotString(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":123}`))) +} + +// ==================== 纯函数测试: extractMediaURLsFromResult ==================== + +func TestExtractMediaURLsFromResult_OAuthPath(t *testing.T) { + result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"} + recorder := httptest.NewRecorder() + url, urls := extractMediaURLsFromResult(result, recorder) + require.Equal(t, "https://oauth.com/video.mp4", url) + require.Equal(t, []string{"https://oauth.com/video.mp4"}, urls) +} + +func TestExtractMediaURLsFromResult_OAuthWithBody(t *testing.T) { + result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"} + recorder := httptest.NewRecorder() + _, _ = recorder.Write([]byte(`{"media_urls":["https://body.com/1.mp4","https://body.com/2.mp4"]}`)) + url, urls := extractMediaURLsFromResult(result, recorder) + require.Equal(t, "https://body.com/1.mp4", url) + require.Len(t, urls, 2) +} + +func TestExtractMediaURLsFromResult_APIKeyPath(t *testing.T) { + recorder := httptest.NewRecorder() + _, _ = 
recorder.Write([]byte(`{"media_url":"https://upstream.com/video.mp4"}`)) + url, urls := extractMediaURLsFromResult(nil, recorder) + require.Equal(t, "https://upstream.com/video.mp4", url) + require.Equal(t, []string{"https://upstream.com/video.mp4"}, urls) +} + +func TestExtractMediaURLsFromResult_NilResultEmptyBody(t *testing.T) { + recorder := httptest.NewRecorder() + url, urls := extractMediaURLsFromResult(nil, recorder) + require.Empty(t, url) + require.Nil(t, urls) +} + +func TestExtractMediaURLsFromResult_EmptyMediaURL(t *testing.T) { + result := &service.ForwardResult{MediaURL: ""} + recorder := httptest.NewRecorder() + url, urls := extractMediaURLsFromResult(result, recorder) + require.Empty(t, url) + require.Nil(t, urls) +} + +// ==================== getUserIDFromContext ==================== + +func TestGetUserIDFromContext_Int64(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set("user_id", int64(42)) + require.Equal(t, int64(42), getUserIDFromContext(c)) +} + +func TestGetUserIDFromContext_AuthSubject(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 777}) + require.Equal(t, int64(777), getUserIDFromContext(c)) +} + +func TestGetUserIDFromContext_Float64(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set("user_id", float64(99)) + require.Equal(t, int64(99), getUserIDFromContext(c)) +} + +func TestGetUserIDFromContext_String(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set("user_id", "123") + require.Equal(t, int64(123), getUserIDFromContext(c)) +} + +func TestGetUserIDFromContext_UserIDFallback(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + 
c.Request = httptest.NewRequest("GET", "/", nil) + c.Set("userID", int64(55)) + require.Equal(t, int64(55), getUserIDFromContext(c)) +} + +func TestGetUserIDFromContext_NoID(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + require.Equal(t, int64(0), getUserIDFromContext(c)) +} + +func TestGetUserIDFromContext_InvalidString(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set("user_id", "not-a-number") + require.Equal(t, int64(0), getUserIDFromContext(c)) +} + +// ==================== Handler: Generate ==================== + +func TestGenerate_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 0) + h.Generate(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestGenerate_BadRequest_MissingModel(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestGenerate_BadRequest_MissingPrompt(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestGenerate_BadRequest_InvalidJSON(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{invalid`, 1) + h.Generate(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestGenerate_TooManyRequests(t *testing.T) { + repo := newStubSoraGenRepo() + repo.countValue = 3 + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", 
`{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusTooManyRequests, rec.Code) +} + +func TestGenerate_CountError(t *testing.T) { + repo := newStubSoraGenRepo() + repo.countErr = fmt.Errorf("db error") + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +func TestGenerate_Success(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"测试生成"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.NotZero(t, data["generation_id"]) + require.Equal(t, "pending", data["status"]) +} + +func TestGenerate_DefaultMediaType(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "video", repo.gens[1].MediaType) +} + +func TestGenerate_ImageMediaType(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"gpt-image","prompt":"test","media_type":"image"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "image", repo.gens[1].MediaType) +} + +func TestGenerate_CreatePendingError(t *testing.T) { + repo := newStubSoraGenRepo() + repo.createErr = fmt.Errorf("create failed") + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, 
http.StatusInternalServerError, rec.Code) +} + +func TestGenerate_NilQuotaServiceSkipsCheck(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) +} + +func TestGenerate_APIKeyInContext(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + c.Set("api_key_id", int64(42)) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, repo.gens[1].APIKeyID) + require.Equal(t, int64(42), *repo.gens[1].APIKeyID) +} + +func TestGenerate_NoAPIKeyInContext(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + require.Nil(t, repo.gens[1].APIKeyID) +} + +func TestGenerate_ConcurrencyBoundary(t *testing.T) { + // activeCount == 2 应该允许 + repo := newStubSoraGenRepo() + repo.countValue = 2 + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) +} + +// ==================== Handler: ListGenerations ==================== + +func TestListGenerations_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 0) + h.ListGenerations(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestListGenerations_Success(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: 
"completed", StorageType: "upstream"} + repo.gens[2] = &service.SoraGeneration{ID: 2, UserID: 1, Model: "gpt-image", Status: "pending", StorageType: "none"} + repo.nextID = 3 + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("GET", "/api/v1/sora/generations?page=1&page_size=10", "", 1) + h.ListGenerations(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + items := data["data"].([]any) + require.Len(t, items, 2) + require.Equal(t, float64(2), data["total"]) +} + +func TestListGenerations_ListError(t *testing.T) { + repo := newStubSoraGenRepo() + repo.listErr = fmt.Errorf("db error") + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1) + h.ListGenerations(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +func TestListGenerations_DefaultPagination(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + // 不传分页参数,应默认 page=1 page_size=20 + c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1) + h.ListGenerations(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, float64(1), data["page"]) +} + +// ==================== Handler: GetGeneration ==================== + +func TestGetGeneration_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 0) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.GetGeneration(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestGetGeneration_InvalidID(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/generations/abc", "", 1) + c.Params = gin.Params{{Key: "id", Value: "abc"}} + h.GetGeneration(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func 
TestGetGeneration_NotFound(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/generations/999", "", 1) + c.Params = gin.Params{{Key: "id", Value: "999"}} + h.GetGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestGetGeneration_WrongUser(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.GetGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestGetGeneration_Success(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.GetGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, float64(1), data["id"]) +} + +// ==================== Handler: DeleteGeneration ==================== + +func TestDeleteGeneration_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 0) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestDeleteGeneration_InvalidID(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/abc", "", 1) + c.Params = gin.Params{{Key: "id", Value: "abc"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + 
+func TestDeleteGeneration_NotFound(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/999", "", 1) + c.Params = gin.Params{{Key: "id", Value: "999"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestDeleteGeneration_WrongUser(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestDeleteGeneration_Success(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) + _, exists := repo.gens[1] + require.False(t, exists) +} + +// ==================== Handler: CancelGeneration ==================== + +func TestCancelGeneration_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 0) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestCancelGeneration_InvalidID(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "abc"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestCancelGeneration_NotFound(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := 
makeGinContext("POST", "/api/v1/sora/generations/999/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "999"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestCancelGeneration_WrongUser(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "pending"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestCancelGeneration_Pending(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "cancelled", repo.gens[1].Status) +} + +func TestCancelGeneration_Generating(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "generating"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "cancelled", repo.gens[1].Status) +} + +func TestCancelGeneration_Completed(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusConflict, rec.Code) +} + +func TestCancelGeneration_Failed(t *testing.T) { + repo := 
newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "failed"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusConflict, rec.Code) +} + +func TestCancelGeneration_Cancelled(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusConflict, rec.Code) +} + +// ==================== Handler: GetQuota ==================== + +func TestGetQuota_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 0) + h.GetQuota(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestGetQuota_NilQuotaService(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1) + h.GetQuota(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, "unlimited", data["source"]) +} + +// ==================== Handler: GetModels ==================== + +func TestGetModels(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/models", "", 0) + h.GetModels(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].([]any) + require.Len(t, data, 4) + // 验证类型分布 + videoCount, imageCount := 0, 0 + for _, item := range data { + m := item.(map[string]any) + if m["type"] == "video" { + videoCount++ + } else if m["type"] == "image" { + imageCount++ + } + } + 
require.Equal(t, 3, videoCount) + require.Equal(t, 1, imageCount) +} + +// ==================== Handler: GetStorageStatus ==================== + +func TestGetStorageStatus_NilS3(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) + h.GetStorageStatus(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, false, data["s3_enabled"]) + require.Equal(t, false, data["s3_healthy"]) + require.Equal(t, false, data["local_enabled"]) +} + +func TestGetStorageStatus_LocalEnabled(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "sora-storage-status-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + h := &SoraClientHandler{mediaStorage: mediaStorage} + + c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) + h.GetStorageStatus(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, false, data["s3_enabled"]) + require.Equal(t, false, data["s3_healthy"]) + require.Equal(t, true, data["local_enabled"]) +} + +// ==================== Handler: SaveToStorage ==================== + +func TestSaveToStorage_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 0) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestSaveToStorage_InvalidID(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: 
"abc"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestSaveToStorage_NotFound(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "999"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestSaveToStorage_NotUpstream(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "s3", MediaURL: "https://example.com/v.mp4"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestSaveToStorage_EmptyMediaURL(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: ""} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestSaveToStorage_S3Nil(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusServiceUnavailable, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, fmt.Sprint(resp["message"]), "云存储") +} + +func TestSaveToStorage_WrongUser(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, 
UserID: 2, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +// ==================== storeMediaWithDegradation — nil guard 路径 ==================== + +func TestStoreMediaWithDegradation_NilS3NilMedia(t *testing.T) { + h := &SoraClientHandler{} + url, urls, storageType, keys, size := h.storeMediaWithDegradation( + context.Background(), 1, "video", "https://upstream.com/v.mp4", nil, + ) + require.Equal(t, service.SoraStorageTypeUpstream, storageType) + require.Equal(t, "https://upstream.com/v.mp4", url) + require.Equal(t, []string{"https://upstream.com/v.mp4"}, urls) + require.Nil(t, keys) + require.Equal(t, int64(0), size) +} + +func TestStoreMediaWithDegradation_NilGuardsMultiURL(t *testing.T) { + h := &SoraClientHandler{} + url, urls, storageType, keys, size := h.storeMediaWithDegradation( + context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, + ) + require.Equal(t, service.SoraStorageTypeUpstream, storageType) + require.Equal(t, "https://a.com/1.mp4", url) + require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls) + require.Nil(t, keys) + require.Equal(t, int64(0), size) +} + +func TestStoreMediaWithDegradation_EmptyMediaURLsFallback(t *testing.T) { + h := &SoraClientHandler{} + url, _, storageType, _, _ := h.storeMediaWithDegradation( + context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{}, + ) + require.Equal(t, service.SoraStorageTypeUpstream, storageType) + require.Equal(t, "https://upstream.com/v.mp4", url) +} + +// ==================== Stub: UserRepository (用于 SoraQuotaService) ==================== + +var _ service.UserRepository = (*stubUserRepoForHandler)(nil) + +type 
stubUserRepoForHandler struct { + users map[int64]*service.User + updateErr error +} + +func newStubUserRepoForHandler() *stubUserRepoForHandler { + return &stubUserRepoForHandler{users: make(map[int64]*service.User)} +} + +func (r *stubUserRepoForHandler) GetByID(_ context.Context, id int64) (*service.User, error) { + if u, ok := r.users[id]; ok { + return u, nil + } + return nil, fmt.Errorf("user not found") +} +func (r *stubUserRepoForHandler) Update(_ context.Context, user *service.User) error { + if r.updateErr != nil { + return r.updateErr + } + r.users[user.ID] = user + return nil +} +func (r *stubUserRepoForHandler) Create(context.Context, *service.User) error { return nil } +func (r *stubUserRepoForHandler) GetByEmail(context.Context, string) (*service.User, error) { + return nil, nil +} +func (r *stubUserRepoForHandler) GetFirstAdmin(context.Context) (*service.User, error) { + return nil, nil +} +func (r *stubUserRepoForHandler) Delete(context.Context, int64) error { return nil } +func (r *stubUserRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubUserRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubUserRepoForHandler) UpdateBalance(context.Context, int64, float64) error { return nil } +func (r *stubUserRepoForHandler) DeductBalance(context.Context, int64, float64) error { return nil } +func (r *stubUserRepoForHandler) UpdateConcurrency(context.Context, int64, int) error { return nil } +func (r *stubUserRepoForHandler) ExistsByEmail(context.Context, string) (bool, error) { + return false, nil +} +func (r *stubUserRepoForHandler) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { + return 0, nil +} +func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) 
error { return nil } +func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil } +func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil } +func (r *stubUserRepoForHandler) AddGroupToAllowedGroups(context.Context, int64, int64) error { + return nil +} + +// ==================== NewSoraClientHandler ==================== + +func TestNewSoraClientHandler(t *testing.T) { + h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil) + require.NotNil(t, h) +} + +func TestNewSoraClientHandler_WithAPIKeyService(t *testing.T) { + h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil) + require.NotNil(t, h) + require.Nil(t, h.apiKeyService) +} + +// ==================== Stub: APIKeyRepository (用于 API Key 校验测试) ==================== + +var _ service.APIKeyRepository = (*stubAPIKeyRepoForHandler)(nil) + +type stubAPIKeyRepoForHandler struct { + keys map[int64]*service.APIKey + getErr error +} + +func newStubAPIKeyRepoForHandler() *stubAPIKeyRepoForHandler { + return &stubAPIKeyRepoForHandler{keys: make(map[int64]*service.APIKey)} +} + +func (r *stubAPIKeyRepoForHandler) GetByID(_ context.Context, id int64) (*service.APIKey, error) { + if r.getErr != nil { + return nil, r.getErr + } + if k, ok := r.keys[id]; ok { + return k, nil + } + return nil, fmt.Errorf("api key not found: %d", id) +} +func (r *stubAPIKeyRepoForHandler) Create(context.Context, *service.APIKey) error { return nil } +func (r *stubAPIKeyRepoForHandler) GetKeyAndOwnerID(_ context.Context, _ int64) (string, int64, error) { + return "", 0, nil +} +func (r *stubAPIKeyRepoForHandler) GetByKey(context.Context, string) (*service.APIKey, error) { + return nil, nil +} +func (r *stubAPIKeyRepoForHandler) GetByKeyForAuth(context.Context, string) (*service.APIKey, error) { + return nil, nil +} +func (r *stubAPIKeyRepoForHandler) Update(context.Context, *service.APIKey) error { return nil } +func (r *stubAPIKeyRepoForHandler) Delete(context.Context, 
int64) error { return nil } +func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubAPIKeyRepoForHandler) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) { + return nil, nil +} +func (r *stubAPIKeyRepoForHandler) CountByUserID(context.Context, int64) (int64, error) { + return 0, nil +} +func (r *stubAPIKeyRepoForHandler) ExistsByKey(context.Context, string) (bool, error) { + return false, nil +} +func (r *stubAPIKeyRepoForHandler) ListByGroupID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubAPIKeyRepoForHandler) SearchAPIKeys(context.Context, int64, string, int) ([]service.APIKey, error) { + return nil, nil +} +func (r *stubAPIKeyRepoForHandler) ClearGroupIDByGroupID(context.Context, int64) (int64, error) { + return 0, nil +} +func (r *stubAPIKeyRepoForHandler) CountByGroupID(context.Context, int64) (int64, error) { + return 0, nil +} +func (r *stubAPIKeyRepoForHandler) ListKeysByUserID(context.Context, int64) ([]string, error) { + return nil, nil +} +func (r *stubAPIKeyRepoForHandler) ListKeysByGroupID(context.Context, int64) ([]string, error) { + return nil, nil +} +func (r *stubAPIKeyRepoForHandler) IncrementQuotaUsed(_ context.Context, _ int64, _ float64) (float64, error) { + return 0, nil +} +func (r *stubAPIKeyRepoForHandler) UpdateLastUsed(context.Context, int64, time.Time) error { + return nil +} +func (r *stubAPIKeyRepoForHandler) IncrementRateLimitUsage(context.Context, int64, float64) error { + return nil +} +func (r *stubAPIKeyRepoForHandler) ResetRateLimitWindows(context.Context, int64) error { + return nil +} +func (r *stubAPIKeyRepoForHandler) GetRateLimitData(context.Context, int64) (*service.APIKeyRateLimitData, error) { + return nil, nil +} + +// 
newTestAPIKeyService 创建测试用的 APIKeyService +func newTestAPIKeyService(repo *stubAPIKeyRepoForHandler) *service.APIKeyService { + return service.NewAPIKeyService(repo, nil, nil, nil, nil, nil, &config.Config{}) +} + +// ==================== Generate: API Key 校验(前端传递 api_key_id)==================== + +func TestGenerate_WithAPIKeyID_Success(t *testing.T) { + // 前端传递 api_key_id,校验通过 → 成功生成,记录关联 api_key_id + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + groupID := int64(5) + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 1, + Status: service.StatusAPIKeyActive, + GroupID: &groupID, + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.NotZero(t, data["generation_id"]) + + // 验证 api_key_id 已关联到生成记录 + gen := repo.gens[1] + require.NotNil(t, gen.APIKeyID) + require.Equal(t, int64(42), *gen.APIKeyID) +} + +func TestGenerate_WithAPIKeyID_NotFound(t *testing.T) { + // 前端传递不存在的 api_key_id → 400 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":999}`, 1) + h.Generate(c) + require.Equal(t, http.StatusBadRequest, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, fmt.Sprint(resp["message"]), "不存在") +} + +func TestGenerate_WithAPIKeyID_WrongUser(t *testing.T) { + // 
前端传递别人的 api_key_id → 403 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 999, // 属于 user 999 + Status: service.StatusAPIKeyActive, + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusForbidden, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, fmt.Sprint(resp["message"]), "不属于") +} + +func TestGenerate_WithAPIKeyID_Disabled(t *testing.T) { + // 前端传递已禁用的 api_key_id → 403 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 1, + Status: service.StatusAPIKeyDisabled, + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusForbidden, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, fmt.Sprint(resp["message"]), "不可用") +} + +func TestGenerate_WithAPIKeyID_QuotaExhausted(t *testing.T) { + // 前端传递配额耗尽的 api_key_id → 403 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 1, + Status: service.StatusAPIKeyQuotaExhausted, + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := 
makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusForbidden, rec.Code) +} + +func TestGenerate_WithAPIKeyID_Expired(t *testing.T) { + // 前端传递已过期的 api_key_id → 403 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 1, + Status: service.StatusAPIKeyExpired, + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusForbidden, rec.Code) +} + +func TestGenerate_WithAPIKeyID_NilAPIKeyService(t *testing.T) { + // apiKeyService 为 nil 时忽略 api_key_id → 正常生成但不记录 api_key_id + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + h := &SoraClientHandler{genService: genService} // apiKeyService = nil + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + // apiKeyService 为 nil → 跳过校验 → api_key_id 不记录 + require.Nil(t, repo.gens[1].APIKeyID) +} + +func TestGenerate_WithAPIKeyID_NilGroupID(t *testing.T) { + // api_key 有效但 GroupID 为 nil → 成功,groupID 为 nil + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 1, + Status: service.StatusAPIKeyActive, + GroupID: nil, // 无分组 + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := 
makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, repo.gens[1].APIKeyID) + require.Equal(t, int64(42), *repo.gens[1].APIKeyID) +} + +func TestGenerate_NoAPIKeyID_NoContext_NilResult(t *testing.T) { + // 既无 api_key_id 字段也无 context 中的 api_key_id → api_key_id 为 nil + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + require.Nil(t, repo.gens[1].APIKeyID) +} + +func TestGenerate_WithAPIKeyIDInBody_OverridesContext(t *testing.T) { + // 同时有 body api_key_id 和 context api_key_id → 优先使用 body 的 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + groupID := int64(10) + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 1, + Status: service.StatusAPIKeyActive, + GroupID: &groupID, + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + c.Set("api_key_id", int64(99)) // context 中有另一个 api_key_id + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + // 应使用 body 中的 api_key_id=42,而不是 context 中的 99 + require.NotNil(t, repo.gens[1].APIKeyID) + require.Equal(t, int64(42), *repo.gens[1].APIKeyID) +} + +func TestGenerate_WithContextAPIKeyID_FallbackPath(t *testing.T) { + // 无 body api_key_id,但 context 有 → 使用 context 中的(兼容网关路由) 
+ repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + c.Set("api_key_id", int64(99)) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + // 应使用 context 中的 api_key_id=99 + require.NotNil(t, repo.gens[1].APIKeyID) + require.Equal(t, int64(99), *repo.gens[1].APIKeyID) +} + +func TestGenerate_APIKeyID_Zero_IgnoredInJSON(t *testing.T) { + // JSON 中 api_key_id=0 被视为 omitempty → 仍然为指针值 0,需要传 nil 检查 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + // JSON 中传了 api_key_id: 0 → 解析后 *int64(0),会触发校验 + // api_key_id=0 不存在 → 400 + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":0}`, 1) + h.Generate(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +// ==================== processGeneration: groupID 传递与 ForcePlatform ==================== + +func TestProcessGeneration_WithGroupID_NoForcePlatform(t *testing.T) { + // groupID 不为 nil → 不设置 ForcePlatform + // gatewayService 为 nil → MarkFailed → 检查错误消息不包含 ForcePlatform 相关 + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + + gid := int64(5) + h.processGeneration(1, 1, &gid, "sora2-landscape-10s", "test", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, 
"gatewayService") +} + +func TestProcessGeneration_NilGroupID_SetsForcePlatform(t *testing.T) { + // groupID 为 nil → 设置 ForcePlatform → gatewayService 为 nil → MarkFailed + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService") +} + +func TestProcessGeneration_MarkGeneratingStateConflict(t *testing.T) { + // 任务状态已变化(如已取消)→ MarkGenerating 返回 ErrSoraGenerationStateConflict → 跳过 + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"} + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + // 状态为 cancelled 时 MarkGenerating 不符合状态转换规则 → 应保持 cancelled + require.Equal(t, "cancelled", repo.gens[1].Status) +} + +// ==================== GenerateRequest JSON 解析 ==================== + +func TestGenerateRequest_WithAPIKeyID_JSONParsing(t *testing.T) { + // 验证 api_key_id 在 JSON 中正确解析为 *int64 + var req GenerateRequest + err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":42}`), &req) + require.NoError(t, err) + require.NotNil(t, req.APIKeyID) + require.Equal(t, int64(42), *req.APIKeyID) +} + +func TestGenerateRequest_WithoutAPIKeyID_JSONParsing(t *testing.T) { + // 不传 api_key_id → 解析后为 nil + var req GenerateRequest + err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test"}`), &req) + require.NoError(t, err) + require.Nil(t, req.APIKeyID) +} + +func TestGenerateRequest_NullAPIKeyID_JSONParsing(t *testing.T) { + // api_key_id: null → 解析后为 nil + var req GenerateRequest + err := 
json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":null}`), &req) + require.NoError(t, err) + require.Nil(t, req.APIKeyID) +} + +func TestGenerateRequest_FullFields_JSONParsing(t *testing.T) { + // 全字段解析 + var req GenerateRequest + err := json.Unmarshal([]byte(`{ + "model":"sora2-landscape-10s", + "prompt":"test prompt", + "media_type":"video", + "video_count":2, + "image_input":"data:image/png;base64,abc", + "api_key_id":100 + }`), &req) + require.NoError(t, err) + require.Equal(t, "sora2-landscape-10s", req.Model) + require.Equal(t, "test prompt", req.Prompt) + require.Equal(t, "video", req.MediaType) + require.Equal(t, 2, req.VideoCount) + require.Equal(t, "data:image/png;base64,abc", req.ImageInput) + require.NotNil(t, req.APIKeyID) + require.Equal(t, int64(100), *req.APIKeyID) +} + +func TestGenerateRequest_JSONSerialize_OmitsNilAPIKeyID(t *testing.T) { + // api_key_id 为 nil 时 JSON 序列化应省略 + req := GenerateRequest{Model: "sora2", Prompt: "test"} + b, err := json.Marshal(req) + require.NoError(t, err) + var parsed map[string]any + require.NoError(t, json.Unmarshal(b, &parsed)) + _, hasAPIKeyID := parsed["api_key_id"] + require.False(t, hasAPIKeyID, "api_key_id 为 nil 时应省略") +} + +func TestGenerateRequest_JSONSerialize_IncludesAPIKeyID(t *testing.T) { + // api_key_id 不为 nil 时 JSON 序列化应包含 + id := int64(42) + req := GenerateRequest{Model: "sora2", Prompt: "test", APIKeyID: &id} + b, err := json.Marshal(req) + require.NoError(t, err) + var parsed map[string]any + require.NoError(t, json.Unmarshal(b, &parsed)) + require.Equal(t, float64(42), parsed["api_key_id"]) +} + +// ==================== GetQuota: 有配额服务 ==================== + +func TestGetQuota_WithQuotaService_Success(t *testing.T) { + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, + SoraStorageQuotaBytes: 10 * 1024 * 1024, + SoraStorageUsedBytes: 3 * 1024 * 1024, + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + + repo := 
newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{ + genService: genService, + quotaService: quotaService, + } + + c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1) + h.GetQuota(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, "user", data["source"]) + require.Equal(t, float64(10*1024*1024), data["quota_bytes"]) + require.Equal(t, float64(3*1024*1024), data["used_bytes"]) +} + +func TestGetQuota_WithQuotaService_Error(t *testing.T) { + // 用户不存在时 GetQuota 返回错误 + userRepo := newStubUserRepoForHandler() + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{ + genService: genService, + quotaService: quotaService, + } + + c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 999) + h.GetQuota(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +// ==================== Generate: 配额检查 ==================== + +func TestGenerate_QuotaCheckFailed(t *testing.T) { + // 配额超限时返回 429 + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, + SoraStorageQuotaBytes: 1024, + SoraStorageUsedBytes: 1025, // 已超限 + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{ + genService: genService, + quotaService: quotaService, + } + + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusTooManyRequests, rec.Code) +} + +func TestGenerate_QuotaCheckPassed(t *testing.T) { + // 配额充足时允许生成 + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, + SoraStorageQuotaBytes: 10 * 1024 * 
1024, + SoraStorageUsedBytes: 0, + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{ + genService: genService, + quotaService: quotaService, + } + + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) +} + +// ==================== Stub: SettingRepository (用于 S3 存储测试) ==================== + +var _ service.SettingRepository = (*stubSettingRepoForHandler)(nil) + +type stubSettingRepoForHandler struct { + values map[string]string +} + +func newStubSettingRepoForHandler(values map[string]string) *stubSettingRepoForHandler { + if values == nil { + values = make(map[string]string) + } + return &stubSettingRepoForHandler{values: values} +} + +func (r *stubSettingRepoForHandler) Get(_ context.Context, key string) (*service.Setting, error) { + if v, ok := r.values[key]; ok { + return &service.Setting{Key: key, Value: v}, nil + } + return nil, service.ErrSettingNotFound +} +func (r *stubSettingRepoForHandler) GetValue(_ context.Context, key string) (string, error) { + if v, ok := r.values[key]; ok { + return v, nil + } + return "", service.ErrSettingNotFound +} +func (r *stubSettingRepoForHandler) Set(_ context.Context, key, value string) error { + r.values[key] = value + return nil +} +func (r *stubSettingRepoForHandler) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + result := make(map[string]string) + for _, k := range keys { + if v, ok := r.values[k]; ok { + result[k] = v + } + } + return result, nil +} +func (r *stubSettingRepoForHandler) SetMultiple(_ context.Context, settings map[string]string) error { + for k, v := range settings { + r.values[k] = v + } + return nil +} +func (r *stubSettingRepoForHandler) GetAll(_ context.Context) (map[string]string, error) { + return r.values, nil 
+} +func (r *stubSettingRepoForHandler) Delete(_ context.Context, key string) error { + delete(r.values, key) + return nil +} + +// ==================== S3 / MediaStorage 辅助函数 ==================== + +// newS3StorageForHandler 创建指向指定 endpoint 的 S3Storage(用于测试)。 +func newS3StorageForHandler(endpoint string) *service.SoraS3Storage { + settingRepo := newStubSettingRepoForHandler(map[string]string{ + "sora_s3_enabled": "true", + "sora_s3_endpoint": endpoint, + "sora_s3_region": "us-east-1", + "sora_s3_bucket": "test-bucket", + "sora_s3_access_key_id": "AKIATEST", + "sora_s3_secret_access_key": "test-secret", + "sora_s3_prefix": "sora", + "sora_s3_force_path_style": "true", + }) + settingService := service.NewSettingService(settingRepo, &config.Config{}) + return service.NewSoraS3Storage(settingService) +} + +// newFakeSourceServer 创建返回固定内容的 HTTP 服务器(模拟上游媒体文件)。 +func newFakeSourceServer() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "video/mp4") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("fake video data for test")) + })) +} + +// newFakeS3Server 创建模拟 S3 的 HTTP 服务器。 +// mode: "ok" 接受所有请求,"fail" 返回 403,"fail-second" 第一次成功第二次失败。 +func newFakeS3Server(mode string) *httptest.Server { + var counter atomic.Int32 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.Copy(io.Discard, r.Body) + _ = r.Body.Close() + + switch mode { + case "ok": + w.Header().Set("ETag", `"test-etag"`) + w.WriteHeader(http.StatusOK) + case "fail": + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`AccessDenied`)) + case "fail-second": + n := counter.Add(1) + if n <= 1 { + w.Header().Set("ETag", `"test-etag"`) + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`AccessDenied`)) + } + } + })) +} + +// ==================== processGeneration 直接调用测试 ==================== + +func 
TestProcessGeneration_MarkGeneratingFails(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + repo.updateErr = fmt.Errorf("db error") + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + + // 直接调用(非 goroutine),MarkGenerating 失败 → 早退 + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + // MarkGenerating 在调用 repo.Update 前已修改内存对象为 "generating" + // repo.Update 返回错误 → processGeneration 早退,不会继续到 MarkFailed + // 因此 ErrorMessage 为空(证明未调用 MarkFailed) + require.Equal(t, "generating", repo.gens[1].Status) + require.Empty(t, repo.gens[1].ErrorMessage) +} + +func TestProcessGeneration_GatewayServiceNil(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + // gatewayService 未设置 → MarkFailed + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService") +} + +// ==================== storeMediaWithDegradation: S3 路径 ==================== + +func TestStoreMediaWithDegradation_S3SuccessSingleURL(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation( + context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, + ) + require.Equal(t, service.SoraStorageTypeS3, storageType) + require.Len(t, s3Keys, 1) + require.NotEmpty(t, s3Keys[0]) + require.Len(t, storedURLs, 1) + require.Equal(t, storedURL, storedURLs[0]) + 
require.Contains(t, storedURL, fakeS3.URL) + require.Contains(t, storedURL, "/test-bucket/") + require.Greater(t, fileSize, int64(0)) +} + +func TestStoreMediaWithDegradation_S3SuccessMultiURL(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"} + storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation( + context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls, + ) + require.Equal(t, service.SoraStorageTypeS3, storageType) + require.Len(t, s3Keys, 2) + require.Len(t, storedURLs, 2) + require.Equal(t, storedURL, storedURLs[0]) + require.Contains(t, storedURLs[0], fakeS3.URL) + require.Contains(t, storedURLs[1], fakeS3.URL) + require.Greater(t, fileSize, int64(0)) +} + +func TestStoreMediaWithDegradation_S3DownloadFails(t *testing.T) { + // 上游返回 404 → 下载失败 → S3 上传不会开始 + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + badSource := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer badSource.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + _, _, storageType, _, _ := h.storeMediaWithDegradation( + context.Background(), 1, "video", badSource.URL+"/missing.mp4", nil, + ) + require.Equal(t, service.SoraStorageTypeUpstream, storageType) +} + +func TestStoreMediaWithDegradation_S3FailsSingleURL(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("fail") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( + context.Background(), 
1, "video", sourceServer.URL+"/v.mp4", nil, + ) + // S3 失败,降级到 upstream + require.Equal(t, service.SoraStorageTypeUpstream, storageType) + require.Nil(t, s3Keys) +} + +func TestStoreMediaWithDegradation_S3PartialFailureCleanup(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("fail-second") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"} + _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( + context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls, + ) + // 第二个 URL 上传失败 → 清理已上传 → 降级到 upstream + require.Equal(t, service.SoraStorageTypeUpstream, storageType) + require.Nil(t, s3Keys) +} + +// ==================== storeMediaWithDegradation: 本地存储路径 ==================== + +func TestStoreMediaWithDegradation_LocalStorageFails(t *testing.T) { + // 使用无效路径,EnsureLocalDirs 失败 → StoreFromURLs 返回 error + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: "/dev/null/invalid_dir", + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + h := &SoraClientHandler{mediaStorage: mediaStorage} + + _, _, storageType, _, _ := h.storeMediaWithDegradation( + context.Background(), 1, "video", "https://upstream.com/v.mp4", nil, + ) + // 本地存储失败,降级到 upstream + require.Equal(t, service.SoraStorageTypeUpstream, storageType) +} + +func TestStoreMediaWithDegradation_LocalStorageSuccess(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "sora-handler-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + DownloadTimeoutSeconds: 5, + MaxDownloadBytes: 10 * 1024 * 1024, + }, + }, + } + 
mediaStorage := service.NewSoraMediaStorage(cfg) + h := &SoraClientHandler{mediaStorage: mediaStorage} + + _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( + context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, + ) + require.Equal(t, service.SoraStorageTypeLocal, storageType) + require.Nil(t, s3Keys) // 本地存储不返回 S3 keys +} + +func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "sora-handler-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("fail") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + DownloadTimeoutSeconds: 5, + MaxDownloadBytes: 10 * 1024 * 1024, + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + h := &SoraClientHandler{ + s3Storage: s3Storage, + mediaStorage: mediaStorage, + } + + _, _, storageType, _, _ := h.storeMediaWithDegradation( + context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, + ) + // S3 失败 → 本地存储成功 + require.Equal(t, service.SoraStorageTypeLocal, storageType) +} + +// ==================== SaveToStorage: S3 路径 ==================== + +func TestSaveToStorage_S3EnabledButUploadFails(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("fail") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v.mp4", + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + 
c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, resp["message"], "S3") +} + +func TestSaveToStorage_UpstreamURLExpired(t *testing.T) { + expiredServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer expiredServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: expiredServer.URL + "/v.mp4", + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusGone, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, fmt.Sprint(resp["message"]), "过期") +} + +func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v.mp4", + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + 
require.Contains(t, data["message"], "S3") + require.NotEmpty(t, data["object_key"]) + // 验证记录已更新为 S3 存储 + require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType) +} + +func TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v1.mp4", + MediaURLs: []string{ + sourceServer.URL + "/v1.mp4", + sourceServer.URL + "/v2.mp4", + }, + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Len(t, data["object_keys"].([]any), 2) + require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType) + require.Len(t, repo.gens[1].S3ObjectKeys, 2) + require.Len(t, repo.gens[1].MediaURLs, 2) +} + +func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v.mp4", + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, + SoraStorageQuotaBytes: 100 * 1024 * 1024, + SoraStorageUsedBytes: 0, + 
} + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusOK, rec.Code) + // 验证配额已累加 + require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0)) +} + +func TestSaveToStorage_S3UploadSuccessMarkCompletedFails(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v.mp4", + } + // S3 上传成功后,MarkCompleted 会调用 repo.Update → 失败 + repo.updateErr = fmt.Errorf("db error") + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +// ==================== GetStorageStatus: S3 路径 ==================== + +func TestGetStorageStatus_S3EnabledNotHealthy(t *testing.T) { + // S3 启用但 TestConnection 失败(fake 端点不响应 HeadBucket) + fakeS3 := newFakeS3Server("fail") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) + h.GetStorageStatus(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, true, data["s3_enabled"]) + require.Equal(t, false, data["s3_healthy"]) +} + +func 
TestGetStorageStatus_S3EnabledHealthy(t *testing.T) { + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) + h.GetStorageStatus(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, true, data["s3_enabled"]) + require.Equal(t, true, data["s3_healthy"]) +} + +// ==================== Stub: AccountRepository (用于 GatewayService) ==================== + +var _ service.AccountRepository = (*stubAccountRepoForHandler)(nil) + +type stubAccountRepoForHandler struct { + accounts []service.Account +} + +func (r *stubAccountRepoForHandler) Create(context.Context, *service.Account) error { return nil } +func (r *stubAccountRepoForHandler) GetByID(_ context.Context, id int64) (*service.Account, error) { + for i := range r.accounts { + if r.accounts[i].ID == id { + return &r.accounts[i], nil + } + } + return nil, fmt.Errorf("account not found") +} +func (r *stubAccountRepoForHandler) GetByIDs(context.Context, []int64) ([]*service.Account, error) { + return nil, nil +} +func (r *stubAccountRepoForHandler) ExistsByID(context.Context, int64) (bool, error) { + return false, nil +} +func (r *stubAccountRepoForHandler) GetByCRSAccountID(context.Context, string) (*service.Account, error) { + return nil, nil +} +func (r *stubAccountRepoForHandler) FindByExtraField(context.Context, string, any) ([]service.Account, error) { + return nil, nil +} +func (r *stubAccountRepoForHandler) ListCRSAccountIDs(context.Context) (map[string]int64, error) { + return nil, nil +} +func (r *stubAccountRepoForHandler) Update(context.Context, *service.Account) error { return nil } +func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error { return nil } +func (r *stubAccountRepoForHandler) List(context.Context, 
pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64) ([]service.Account, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) { + return nil, nil +} +func (r *stubAccountRepoForHandler) ListActive(context.Context) ([]service.Account, error) { + return nil, nil +} +func (r *stubAccountRepoForHandler) ListByPlatform(context.Context, string) ([]service.Account, error) { + return nil, nil +} +func (r *stubAccountRepoForHandler) UpdateLastUsed(context.Context, int64) error { return nil } +func (r *stubAccountRepoForHandler) BatchUpdateLastUsed(context.Context, map[int64]time.Time) error { + return nil +} +func (r *stubAccountRepoForHandler) SetError(context.Context, int64, string) error { return nil } +func (r *stubAccountRepoForHandler) ClearError(context.Context, int64) error { return nil } +func (r *stubAccountRepoForHandler) SetSchedulable(context.Context, int64, bool) error { + return nil +} +func (r *stubAccountRepoForHandler) AutoPauseExpiredAccounts(context.Context, time.Time) (int64, error) { + return 0, nil +} +func (r *stubAccountRepoForHandler) BindGroups(context.Context, int64, []int64) error { return nil } +func (r *stubAccountRepoForHandler) ListSchedulable(context.Context) ([]service.Account, error) { + return r.accounts, nil +} +func (r *stubAccountRepoForHandler) ListSchedulableByGroupID(context.Context, int64) ([]service.Account, error) { + return r.accounts, nil +} +func (r *stubAccountRepoForHandler) ListSchedulableByPlatform(_ context.Context, _ string) ([]service.Account, error) { + return r.accounts, nil +} +func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatform(context.Context, int64, string) ([]service.Account, error) { + 
return r.accounts, nil +} +func (r *stubAccountRepoForHandler) ListSchedulableByPlatforms(context.Context, []string) ([]service.Account, error) { + return r.accounts, nil +} +func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) { + return r.accounts, nil +} +func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatform(_ context.Context, _ string) ([]service.Account, error) { + return r.accounts, nil +} +func (r *stubAccountRepoForHandler) ListSchedulableUngroupedByPlatforms(_ context.Context, _ []string) ([]service.Account, error) { + return r.accounts, nil +} +func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error { + return nil +} +func (r *stubAccountRepoForHandler) SetModelRateLimit(context.Context, int64, string, time.Time) error { + return nil +} +func (r *stubAccountRepoForHandler) SetOverloaded(context.Context, int64, time.Time) error { + return nil +} +func (r *stubAccountRepoForHandler) SetTempUnschedulable(context.Context, int64, time.Time, string) error { + return nil +} +func (r *stubAccountRepoForHandler) ClearTempUnschedulable(context.Context, int64) error { return nil } +func (r *stubAccountRepoForHandler) ClearRateLimit(context.Context, int64) error { return nil } +func (r *stubAccountRepoForHandler) ClearAntigravityQuotaScopes(context.Context, int64) error { + return nil +} +func (r *stubAccountRepoForHandler) ClearModelRateLimits(context.Context, int64) error { return nil } +func (r *stubAccountRepoForHandler) UpdateSessionWindow(context.Context, int64, *time.Time, *time.Time, string) error { + return nil +} +func (r *stubAccountRepoForHandler) UpdateExtra(context.Context, int64, map[string]any) error { + return nil +} +func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service.AccountBulkUpdate) (int64, error) { + return 0, nil +} + +// ==================== Stub: SoraClient (用于 SoraGatewayService) 
==================== + +var _ service.SoraClient = (*stubSoraClientForHandler)(nil) + +type stubSoraClientForHandler struct { + videoStatus *service.SoraVideoTaskStatus +} + +func (s *stubSoraClientForHandler) Enabled() bool { return true } +func (s *stubSoraClientForHandler) UploadImage(context.Context, *service.Account, []byte, string) (string, error) { + return "", nil +} +func (s *stubSoraClientForHandler) CreateImageTask(context.Context, *service.Account, service.SoraImageRequest) (string, error) { + return "task-image", nil +} +func (s *stubSoraClientForHandler) CreateVideoTask(context.Context, *service.Account, service.SoraVideoRequest) (string, error) { + return "task-video", nil +} +func (s *stubSoraClientForHandler) CreateStoryboardTask(context.Context, *service.Account, service.SoraStoryboardRequest) (string, error) { + return "task-video", nil +} +func (s *stubSoraClientForHandler) UploadCharacterVideo(context.Context, *service.Account, []byte) (string, error) { + return "", nil +} +func (s *stubSoraClientForHandler) GetCameoStatus(context.Context, *service.Account, string) (*service.SoraCameoStatus, error) { + return nil, nil +} +func (s *stubSoraClientForHandler) DownloadCharacterImage(context.Context, *service.Account, string) ([]byte, error) { + return nil, nil +} +func (s *stubSoraClientForHandler) UploadCharacterImage(context.Context, *service.Account, []byte) (string, error) { + return "", nil +} +func (s *stubSoraClientForHandler) FinalizeCharacter(context.Context, *service.Account, service.SoraCharacterFinalizeRequest) (string, error) { + return "", nil +} +func (s *stubSoraClientForHandler) SetCharacterPublic(context.Context, *service.Account, string) error { + return nil +} +func (s *stubSoraClientForHandler) DeleteCharacter(context.Context, *service.Account, string) error { + return nil +} +func (s *stubSoraClientForHandler) PostVideoForWatermarkFree(context.Context, *service.Account, string) (string, error) { + return "", nil +} +func (s 
*stubSoraClientForHandler) DeletePost(context.Context, *service.Account, string) error { + return nil +} +func (s *stubSoraClientForHandler) GetWatermarkFreeURLCustom(context.Context, *service.Account, string, string, string) (string, error) { + return "", nil +} +func (s *stubSoraClientForHandler) EnhancePrompt(context.Context, *service.Account, string, string, int) (string, error) { + return "", nil +} +func (s *stubSoraClientForHandler) GetImageTask(context.Context, *service.Account, string) (*service.SoraImageTaskStatus, error) { + return nil, nil +} +func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Account, _ string) (*service.SoraVideoTaskStatus, error) { + return s.videoStatus, nil +} + +// ==================== 辅助:创建最小 GatewayService 和 SoraGatewayService ==================== + +// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 SelectAccountForModel)。 +func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService { + return service.NewGatewayService( + accountRepo, nil, nil, nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, + ) +} + +// newMinimalSoraGatewayService 创建最小 SoraGatewayService(用于测试 Forward)。 +func newMinimalSoraGatewayService(soraClient service.SoraClient) *service.SoraGatewayService { + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + return service.NewSoraGatewayService(soraClient, nil, nil, cfg) +} + +// ==================== processGeneration: 更多路径测试 ==================== + +func TestProcessGeneration_SelectAccountError(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + // accountRepo 返回空列表 → SelectAccountForModel 返回 "no available accounts" + accountRepo := &stubAccountRepoForHandler{accounts: 
nil} + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{genService: genService, gatewayService: gatewayService} + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败") +} + +func TestProcessGeneration_SoraGatewayServiceNil(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + // 提供可用账号使 SelectAccountForModel 成功 + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + // soraGatewayService 为 nil + h := &SoraClientHandler{genService: genService, gatewayService: gatewayService} + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "soraGatewayService") +} + +func TestProcessGeneration_ForwardError(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + // SoraClient 返回视频任务失败 + soraClient := &stubSoraClientForHandler{ + videoStatus: &service.SoraVideoTaskStatus{ + Status: "failed", + ErrorMsg: "content policy violation", + }, + } + soraGatewayService := 
newMinimalSoraGatewayService(soraClient) + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + } + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "生成失败") +} + +func TestProcessGeneration_ForwardErrorCancelled(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + // MarkGenerating 内部调用 GetByID(第 1 次),Forward 失败后 processGeneration + // 调用 GetByID(第 2 次)。模拟外部在 Forward 期间取消了任务。 + repo.getByIDOverrideAfterN = 1 + repo.getByIDOverrideStatus = "cancelled" + genService := service.NewSoraGenerationService(repo, nil, nil) + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + soraClient := &stubSoraClientForHandler{ + videoStatus: &service.SoraVideoTaskStatus{Status: "failed", ErrorMsg: "reject"}, + } + soraGatewayService := newMinimalSoraGatewayService(soraClient) + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + } + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + // Forward 失败后检测到外部取消,不应调用 MarkFailed(状态保持 generating) + require.Equal(t, "generating", repo.gens[1].Status) +} + +func TestProcessGeneration_ForwardSuccessNoMediaURL(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Platform: service.PlatformSora, Status: 
service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + // SoraClient 返回 completed 但无 URL + soraClient := &stubSoraClientForHandler{ + videoStatus: &service.SoraVideoTaskStatus{ + Status: "completed", + URLs: nil, // 无 URL + }, + } + soraGatewayService := newMinimalSoraGatewayService(soraClient) + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + } + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "未获取到媒体 URL") +} + +func TestProcessGeneration_ForwardSuccessCancelledBeforeStore(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + // MarkGenerating 调用 GetByID(第 1 次),之后 processGeneration 行 176 调用 GetByID(第 2 次) + // 第 2 次返回 "cancelled" 状态,模拟外部取消 + repo.getByIDOverrideAfterN = 1 + repo.getByIDOverrideStatus = "cancelled" + genService := service.NewSoraGenerationService(repo, nil, nil) + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + soraClient := &stubSoraClientForHandler{ + videoStatus: &service.SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/video.mp4"}, + }, + } + soraGatewayService := newMinimalSoraGatewayService(soraClient) + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + } + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + // Forward 成功后检测到外部取消,不应调用存储和 MarkCompleted(状态保持 generating) + require.Equal(t, "generating", repo.gens[1].Status) +} + +func TestProcessGeneration_FullSuccessUpstream(t 
*testing.T) { + t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + soraClient := &stubSoraClientForHandler{ + videoStatus: &service.SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/video.mp4"}, + }, + } + soraGatewayService := newMinimalSoraGatewayService(soraClient) + // 无 S3 和本地存储,降级到 upstream + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + } + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) + require.Equal(t, "completed", repo.gens[1].Status) + require.Equal(t, service.SoraStorageTypeUpstream, repo.gens[1].StorageType) + require.NotEmpty(t, repo.gens[1].MediaURL) +} + +func TestProcessGeneration_FullSuccessWithS3(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + soraClient := &stubSoraClientForHandler{ + videoStatus: &service.SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{sourceServer.URL + "/video.mp4"}, + }, + } + 
soraGatewayService := newMinimalSoraGatewayService(soraClient) + s3Storage := newS3StorageForHandler(fakeS3.URL) + + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, SoraStorageQuotaBytes: 100 * 1024 * 1024, + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + s3Storage: s3Storage, + quotaService: quotaService, + } + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) + require.Equal(t, "completed", repo.gens[1].Status) + require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType) + require.NotEmpty(t, repo.gens[1].S3ObjectKeys) + require.Greater(t, repo.gens[1].FileSizeBytes, int64(0)) + // 验证配额已累加 + require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0)) +} + +func TestProcessGeneration_MarkCompletedFails(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + // 第 1 次 Update(MarkGenerating)成功,第 2 次(MarkCompleted)失败 + repo.updateCallCount = new(int32) + repo.updateFailAfterN = 1 + genService := service.NewSoraGenerationService(repo, nil, nil) + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + soraClient := &stubSoraClientForHandler{ + videoStatus: &service.SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/video.mp4"}, + }, + } + soraGatewayService := newMinimalSoraGatewayService(soraClient) + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + } + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", 
"test prompt", "video", "", 1) + // MarkCompleted 内部先修改内存对象状态为 completed,然后 Update 失败。 + // 由于 stub 存储的是指针,内存中的状态已被修改为 completed。 + // 此测试验证 processGeneration 在 MarkCompleted 失败后提前返回(不调用 AddUsage)。 + require.Equal(t, "completed", repo.gens[1].Status) +} + +// ==================== cleanupStoredMedia 直接测试 ==================== + +func TestCleanupStoredMedia_S3Path(t *testing.T) { + // S3 清理路径:s3Storage 为 nil 时不 panic + h := &SoraClientHandler{} + // 不应 panic + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil) +} + +func TestCleanupStoredMedia_LocalPath(t *testing.T) { + // 本地清理路径:mediaStorage 为 nil 时不 panic + h := &SoraClientHandler{} + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"/tmp/test.mp4"}) +} + +func TestCleanupStoredMedia_UpstreamPath(t *testing.T) { + // upstream 类型不清理 + h := &SoraClientHandler{} + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeUpstream, nil, nil) +} + +func TestCleanupStoredMedia_EmptyKeys(t *testing.T) { + // 空 keys 不触发清理 + h := &SoraClientHandler{} + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, nil, nil) + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, nil) +} + +// ==================== DeleteGeneration: 本地存储清理路径 ==================== + +func TestDeleteGeneration_LocalStorageCleanup(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "sora-delete-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, + UserID: 1, + Status: "completed", + StorageType: service.SoraStorageTypeLocal, + MediaURL: "video/test.mp4", + MediaURLs: []string{"video/test.mp4"}, + } + genService := 
service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage} + + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) + _, exists := repo.gens[1] + require.False(t, exists) +} + +func TestDeleteGeneration_LocalStorageCleanup_MediaURLFallback(t *testing.T) { + // MediaURLs 为空,使用 MediaURL 作为清理路径 + tmpDir, err := os.MkdirTemp("", "sora-delete-fallback-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, + UserID: 1, + Status: "completed", + StorageType: service.SoraStorageTypeLocal, + MediaURL: "video/test.mp4", + MediaURLs: nil, // 空 + } + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage} + + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) +} + +func TestDeleteGeneration_NonLocalStorage_SkipCleanup(t *testing.T) { + // 非本地存储类型 → 跳过清理 + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, + UserID: 1, + Status: "completed", + StorageType: service.SoraStorageTypeUpstream, + MediaURL: "https://upstream.com/v.mp4", + } + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) +} + +func 
TestDeleteGeneration_DeleteError(t *testing.T) { + // repo.Delete 出错 + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream"} + repo.deleteErr = fmt.Errorf("delete failed") + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +// ==================== fetchUpstreamModels 测试 ==================== + +func TestFetchUpstreamModels_NilGateway(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + h := &SoraClientHandler{} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "gatewayService 未初始化") +} + +func TestFetchUpstreamModels_NoAccounts(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + accountRepo := &stubAccountRepoForHandler{accounts: nil} + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "选择 Sora 账号失败") +} + +func TestFetchUpstreamModels_NonAPIKeyAccount(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: "oauth", Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "不支持模型同步") +} + +func TestFetchUpstreamModels_MissingAPIKey(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 
上游模型同步相关测试,待账号选择逻辑稳定后恢复") + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"base_url": "https://sora.test"}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "api_key") +} + +func TestFetchUpstreamModels_MissingBaseURL_FallsBackToDefault(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + // GetBaseURL() 在缺少 base_url 时返回默认值 "https://api.anthropic.com" + // 因此不会触发 "账号缺少 base_url" 错误,而是会尝试请求默认 URL 并失败 + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test"}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) +} + +func TestFetchUpstreamModels_UpstreamReturns500(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer ts.Close() + + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + 
require.Error(t, err) + require.Contains(t, err.Error(), "状态码 500") +} + +func TestFetchUpstreamModels_UpstreamReturnsInvalidJSON(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("not json")) + })) + defer ts.Close() + + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "解析响应失败") +} + +func TestFetchUpstreamModels_UpstreamReturnsEmptyList(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":[]}`)) + })) + defer ts.Close() + + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "空模型列表") +} + +func TestFetchUpstreamModels_Success(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 验证请求头 + require.Equal(t, "Bearer sk-test", 
r.Header.Get("Authorization")) + require.True(t, strings.HasSuffix(r.URL.Path, "/sora/v1/models")) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"sora2-portrait-10s"},{"id":"sora2-landscape-15s"},{"id":"gpt-image"}]}`)) + })) + defer ts.Close() + + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + families, err := h.fetchUpstreamModels(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, families) +} + +func TestFetchUpstreamModels_UnrecognizedModels(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":[{"id":"unknown-model-1"},{"id":"unknown-model-2"}]}`)) + })) + defer ts.Close() + + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "未能从上游模型列表中识别") +} + +// ==================== getModelFamilies 缓存测试 ==================== + +func TestGetModelFamilies_CachesLocalConfig(t *testing.T) { + // gatewayService 为 nil → fetchUpstreamModels 失败 → 降级到本地配置 + h := &SoraClientHandler{} + families := h.getModelFamilies(context.Background()) + 
require.NotEmpty(t, families) + + // 第二次调用应命中缓存(modelCacheUpstream=false → 使用短 TTL) + families2 := h.getModelFamilies(context.Background()) + require.Equal(t, families, families2) + require.False(t, h.modelCacheUpstream) +} + +func TestGetModelFamilies_CachesUpstreamResult(t *testing.T) { + t.Skip("TODO: 临时屏蔽依赖 Sora 上游模型同步的缓存测试,待账号选择逻辑稳定后恢复") + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"gpt-image"}]}`)) + })) + defer ts.Close() + + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + + families := h.getModelFamilies(context.Background()) + require.NotEmpty(t, families) + require.True(t, h.modelCacheUpstream) + + // 第二次调用命中缓存 + families2 := h.getModelFamilies(context.Background()) + require.Equal(t, families, families2) +} + +func TestGetModelFamilies_ExpiredCacheRefreshes(t *testing.T) { + // 预设过期的缓存(modelCacheUpstream=false → 短 TTL) + h := &SoraClientHandler{ + cachedFamilies: []service.SoraModelFamily{{ID: "old"}}, + modelCacheTime: time.Now().Add(-10 * time.Minute), // 已过期 + modelCacheUpstream: false, + } + // gatewayService 为 nil → fetchUpstreamModels 失败 → 使用本地配置刷新缓存 + families := h.getModelFamilies(context.Background()) + require.NotEmpty(t, families) + // 缓存已刷新,不再是 "old" + found := false + for _, f := range families { + if f.ID == "old" { + found = true + } + } + require.False(t, found, "过期缓存应被刷新") +} + +// ==================== processGeneration: groupID 与 ForcePlatform ==================== + +func TestProcessGeneration_NilGroupID_WithGateway_SelectAccountFails(t *testing.T) 
{ + // groupID 为 nil → 设置 ForcePlatform=sora → 无可用 sora 账号 → MarkFailed + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + + // 空账号列表 → SelectAccountForModel 失败 + accountRepo := &stubAccountRepoForHandler{accounts: nil} + gatewayService := newMinimalGatewayService(accountRepo) + + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + } + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败") +} + +// ==================== Generate: 配额检查非 QuotaExceeded 错误 ==================== + +func TestGenerate_CheckQuotaNonQuotaError(t *testing.T) { + // quotaService.CheckQuota 返回非 QuotaExceededError → 返回 403 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + // 用户不存在 → GetByID 失败 → CheckQuota 返回普通 error + userRepo := newStubUserRepoForHandler() + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + + h := NewSoraClientHandler(genService, quotaService, nil, nil, nil, nil, nil) + + body := `{"model":"sora2-landscape-10s","prompt":"test"}` + c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1) + h.Generate(c) + require.Equal(t, http.StatusForbidden, rec.Code) +} + +// ==================== Generate: CreatePending 并发限制错误 ==================== + +// stubSoraGenRepoWithAtomicCreate 实现 soraGenerationRepoAtomicCreator 接口 +type stubSoraGenRepoWithAtomicCreate struct { + stubSoraGenRepo + limitErr error +} + +func (r *stubSoraGenRepoWithAtomicCreate) CreatePendingWithLimit(_ context.Context, gen *service.SoraGeneration, _ []string, _ int64) error { + if r.limitErr != nil { + return r.limitErr + } + return r.stubSoraGenRepo.Create(context.Background(), gen) +} + +func TestGenerate_CreatePendingConcurrencyLimit(t *testing.T) { + 
repo := &stubSoraGenRepoWithAtomicCreate{ + stubSoraGenRepo: *newStubSoraGenRepo(), + limitErr: service.ErrSoraGenerationConcurrencyLimit, + } + genService := service.NewSoraGenerationService(repo, nil, nil) + h := NewSoraClientHandler(genService, nil, nil, nil, nil, nil, nil) + + body := `{"model":"sora2-landscape-10s","prompt":"test"}` + c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1) + h.Generate(c) + require.Equal(t, http.StatusTooManyRequests, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, resp["message"], "3") +} + +// ==================== SaveToStorage: 配额超限 ==================== + +func TestSaveToStorage_QuotaExceeded(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v.mp4", + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + + // 用户配额已满 + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, + SoraStorageQuotaBytes: 10, + SoraStorageUsedBytes: 10, + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusTooManyRequests, rec.Code) +} + +// ==================== SaveToStorage: 配额非 QuotaExceeded 错误 ==================== + +func TestSaveToStorage_QuotaNonQuotaError(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + 
ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v.mp4", + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + + // 用户不存在 → GetByID 失败 → AddUsage 返回普通 error + userRepo := newStubUserRepoForHandler() + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +// ==================== SaveToStorage: MediaURLs 全为空 ==================== + +func TestSaveToStorage_EmptyMediaURLs(t *testing.T) { + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: "", + MediaURLs: []string{}, + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusBadRequest, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, resp["message"], "已过期") +} + +// ==================== SaveToStorage: S3 上传失败时已有已上传文件需清理 ==================== + +func TestSaveToStorage_MultiURL_SecondUploadFails(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("fail-second") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: 
sourceServer.URL + "/v1.mp4", + MediaURLs: []string{sourceServer.URL + "/v1.mp4", sourceServer.URL + "/v2.mp4"}, + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +// ==================== SaveToStorage: UpdateStorageForCompleted 失败(含配额回滚) ==================== + +func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v.mp4", + } + repo.updateErr = fmt.Errorf("db error") + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, + SoraStorageQuotaBytes: 100 * 1024 * 1024, + SoraStorageUsedBytes: 0, + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +// ==================== cleanupStoredMedia: 实际 S3 删除路径 ==================== + +func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) { + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := 
&SoraClientHandler{s3Storage: s3Storage} + + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1", "key2"}, nil) +} + +func TestCleanupStoredMedia_S3DeleteFails_LogOnly(t *testing.T) { + fakeS3 := newFakeS3Server("fail") + defer fakeS3.Close() + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil) +} + +func TestCleanupStoredMedia_LocalDeleteFails_LogOnly(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "sora-cleanup-fail-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + h := &SoraClientHandler{mediaStorage: mediaStorage} + + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"nonexistent/file.mp4"}) +} + +// ==================== DeleteGeneration: 本地文件删除失败(仅日志) ==================== + +func TestDeleteGeneration_LocalStorageDeleteFails_LogOnly(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "sora-del-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: service.SoraStorageTypeLocal, + MediaURL: "nonexistent/video.mp4", + MediaURLs: []string{"nonexistent/video.mp4"}, + } + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage} + + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: 
"1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) +} + +// ==================== CancelGeneration: 任务已结束冲突 ==================== + +func TestCancelGeneration_AlreadyCompleted(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"} + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusConflict, rec.Code) +} diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go new file mode 100644 index 00000000..48c1e451 --- /dev/null +++ b/backend/internal/handler/sora_gateway_handler.go @@ -0,0 +1,685 @@ +package handler + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "net/http" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/Wei-Shaw/sub2api/internal/util/soraerror" + + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.uber.org/zap" +) + +// SoraGatewayHandler handles Sora chat completions requests +type SoraGatewayHandler struct { + gatewayService *service.GatewayService + soraGatewayService *service.SoraGatewayService + billingCacheService *service.BillingCacheService + usageRecordWorkerPool *service.UsageRecordWorkerPool + concurrencyHelper *ConcurrencyHelper + maxAccountSwitches int + streamMode string + soraTLSEnabled bool + 
soraMediaSigningKey string + soraMediaRoot string +} + +// NewSoraGatewayHandler creates a new SoraGatewayHandler +func NewSoraGatewayHandler( + gatewayService *service.GatewayService, + soraGatewayService *service.SoraGatewayService, + concurrencyService *service.ConcurrencyService, + billingCacheService *service.BillingCacheService, + usageRecordWorkerPool *service.UsageRecordWorkerPool, + cfg *config.Config, +) *SoraGatewayHandler { + pingInterval := time.Duration(0) + maxAccountSwitches := 3 + streamMode := "force" + soraTLSEnabled := true + signKey := "" + mediaRoot := "/app/data/sora" + if cfg != nil { + pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second + if cfg.Gateway.MaxAccountSwitches > 0 { + maxAccountSwitches = cfg.Gateway.MaxAccountSwitches + } + if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" { + streamMode = mode + } + soraTLSEnabled = !cfg.Sora.Client.DisableTLSFingerprint + signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey) + if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" { + mediaRoot = root + } + } + return &SoraGatewayHandler{ + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + billingCacheService: billingCacheService, + usageRecordWorkerPool: usageRecordWorkerPool, + concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), + maxAccountSwitches: maxAccountSwitches, + streamMode: strings.ToLower(streamMode), + soraTLSEnabled: soraTLSEnabled, + soraMediaSigningKey: signKey, + soraMediaRoot: mediaRoot, + } +} + +// ChatCompletions handles Sora /v1/chat/completions endpoint +func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.errorResponse(c, 
http.StatusInternalServerError, "api_error", "User context not found") + return + } + reqLog := requestLogger( + c, + "handler.sora_gateway.chat_completions", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) + if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") + return + } + if len(body) == 0 { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + + setOpsRequestContext(c, "", false, body) + + // 校验请求体 JSON 合法性 + if !gjson.ValidBytes(body) { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } + + // 使用 gjson 只读提取字段做校验,避免完整 Unmarshal + modelResult := gjson.GetBytes(body, "model") + if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return + } + reqModel := modelResult.String() + + msgsResult := gjson.GetBytes(body, "messages") + if !msgsResult.IsArray() || len(msgsResult.Array()) == 0 { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required") + return + } + + clientStream := gjson.GetBytes(body, "stream").Bool() + reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", clientStream)) + if !clientStream { + if h.streamMode == "error" { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true") + return + } + var err error + body, err = sjson.SetBytes(body, "stream", true) + if err != nil { + h.errorResponse(c, http.StatusInternalServerError, 
"api_error", "Failed to process request") + return + } + } + + setOpsRequestContext(c, reqModel, clientStream, body) + + platform := "" + if forced, ok := middleware2.GetForcePlatformFromContext(c); ok { + platform = forced + } else if apiKey.Group != nil { + platform = apiKey.Group.Platform + } + if platform != service.PlatformSora { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "This endpoint only supports Sora platform") + return + } + + streamStarted := false + subscription, _ := middleware2.GetSubscriptionFromContext(c) + + maxWait := service.CalculateMaxWait(subject.Concurrency) + canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) + waitCounted := false + if err != nil { + reqLog.Warn("sora.user_wait_counter_increment_failed", zap.Error(err)) + } else if !canWait { + reqLog.Info("sora.user_wait_queue_full", zap.Int("max_wait", maxWait)) + h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") + return + } + if err == nil && canWait { + waitCounted = true + } + defer func() { + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + } + }() + + userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, clientStream, &streamStarted) + if err != nil { + reqLog.Warn("sora.user_slot_acquire_failed", zap.Error(err)) + h.handleConcurrencyError(c, err, "user", streamStarted) + return + } + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + waitCounted = false + } + userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) + if userReleaseFunc != nil { + defer userReleaseFunc() + } + + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("sora.billing_eligibility_check_failed", 
zap.Error(err)) + status, code, message := billingErrorDetails(err) + h.handleStreamingAwareError(c, status, code, message, streamStarted) + return + } + + sessionHash := generateOpenAISessionHash(c, body) + + maxAccountSwitches := h.maxAccountSwitches + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + lastFailoverStatus := 0 + var lastFailoverBody []byte + var lastFailoverHeaders http.Header + + for { + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "") + if err != nil { + reqLog.Warn("sora.account_select_failed", + zap.Error(err), + zap.Int("excluded_account_count", len(failedAccountIDs)), + ) + if len(failedAccountIDs) == 0 { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) + return + } + rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody) + fields := []zap.Field{ + zap.Int("last_upstream_status", lastFailoverStatus), + } + if rayID != "" { + fields = append(fields, zap.String("last_upstream_cf_ray", rayID)) + } + if mitigated != "" { + fields = append(fields, zap.String("last_upstream_cf_mitigated", mitigated)) + } + if contentType != "" { + fields = append(fields, zap.String("last_upstream_content_type", contentType)) + } + reqLog.Warn("sora.failover_exhausted_no_available_accounts", fields...) 
+ h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted) + return + } + account := selection.Account + setOpsSelectedAccount(c, account.ID, account.Platform) + proxyBound := account.ProxyID != nil + proxyID := int64(0) + if account.ProxyID != nil { + proxyID = *account.ProxyID + } + tlsFingerprintEnabled := h.soraTLSEnabled + + accountReleaseFunc := selection.ReleaseFunc + if !selection.Acquired { + if selection.WaitPlan == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) + return + } + accountWaitCounted := false + canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) + if err != nil { + reqLog.Warn("sora.account_wait_counter_increment_failed", + zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Error(err), + ) + } else if !canWait { + reqLog.Info("sora.account_wait_queue_full", + zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), + ) + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) + return + } + if err == nil && canWait { + accountWaitCounted = true + } + defer func() { + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + } + }() + + accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + clientStream, + &streamStarted, + ) + if err != nil { + reqLog.Warn("sora.account_slot_acquire_failed", + 
zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Error(err), + ) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false + } + } + accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + + result, err := h.soraGatewayService.Forward(c.Request.Context(), c, account, body, clientStream) + if accountReleaseFunc != nil { + accountReleaseFunc() + } + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + failedAccountIDs[account.ID] = struct{}{} + if switchCount >= maxAccountSwitches { + lastFailoverStatus = failoverErr.StatusCode + lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders) + lastFailoverBody = failoverErr.ResponseBody + rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody) + fields := []zap.Field{ + zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + } + if rayID != "" { + fields = append(fields, zap.String("upstream_cf_ray", rayID)) + } + if mitigated != "" { + fields = append(fields, zap.String("upstream_cf_mitigated", mitigated)) + } + if contentType != "" { + fields = append(fields, zap.String("upstream_content_type", contentType)) + } + reqLog.Warn("sora.upstream_failover_exhausted", fields...) 
+ h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted) + return + } + lastFailoverStatus = failoverErr.StatusCode + lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders) + lastFailoverBody = failoverErr.ResponseBody + switchCount++ + upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody) + rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody) + fields := []zap.Field{ + zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.String("upstream_error_code", upstreamErrCode), + zap.String("upstream_error_message", upstreamErrMsg), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + } + if rayID != "" { + fields = append(fields, zap.String("upstream_cf_ray", rayID)) + } + if mitigated != "" { + fields = append(fields, zap.String("upstream_cf_mitigated", mitigated)) + } + if contentType != "" { + fields = append(fields, zap.String("upstream_content_type", contentType)) + } + reqLog.Warn("sora.upstream_failover_switching", fields...) 
+ continue + } + reqLog.Error("sora.forward_failed", + zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Error(err), + ) + return + } + + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + + // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 + h.submitUsageRecordTask(func(ctx context.Context) { + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + }); err != nil { + logger.L().With( + zap.String("component", "handler.sora_gateway.chat_completions"), + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.String("model", reqModel), + zap.Int64("account_id", account.ID), + ).Error("sora.record_usage_failed", zap.Error(err)) + } + }) + reqLog.Debug("sora.request_completed", + zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Int("switch_count", switchCount), + ) + return + } +} + +func generateOpenAISessionHash(c *gin.Context, body []byte) string { + if c == nil { + return "" + } + sessionID := strings.TrimSpace(c.GetHeader("session_id")) + if sessionID == "" { + sessionID = strings.TrimSpace(c.GetHeader("conversation_id")) + } + if sessionID == "" && len(body) > 0 { + sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) + } + if sessionID == "" { + return "" + } + hash := sha256.Sum256([]byte(sessionID)) + return hex.EncodeToString(hash[:]) +} + +func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { + if task == nil { + return + } + if h.usageRecordWorkerPool != nil { + 
h.usageRecordWorkerPool.Submit(task) + return + } + // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。 + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + defer func() { + if recovered := recover(); recovered != nil { + logger.L().With( + zap.String("component", "handler.sora_gateway.chat_completions"), + zap.Any("panic", recovered), + ).Error("sora.usage_record_task_panic_recovered") + } + }() + task(ctx) +} + +func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", + fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) +} + +func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) { + status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody) + h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) +} + +func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) { + if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) { + baseMsg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode) + return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody) + } + + upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody) + if strings.EqualFold(upstreamCode, "cf_shield_429") { + baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry." 
+ return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody) + } + if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) { + switch statusCode { + case 401, 403, 404, 500, 502, 503, 504: + return http.StatusBadGateway, "upstream_error", upstreamMessage + case 429: + return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage + } + } + + switch statusCode { + case 401: + return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator" + case 403: + return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator" + case 404: + if strings.EqualFold(upstreamCode, "unsupported_country_code") { + return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator" + } + return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator" + case 429: + return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later" + case 529: + return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later" + case 500, 502, 503, 504: + return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable" + default: + return http.StatusBadGateway, "upstream_error", "Upstream request failed" + } +} + +func cloneHTTPHeaders(headers http.Header) http.Header { + if headers == nil { + return nil + } + return headers.Clone() +} + +func extractSoraFailoverHeaderInsights(headers http.Header, body []byte) (rayID, mitigated, contentType string) { + if headers != nil { + mitigated = strings.TrimSpace(headers.Get("cf-mitigated")) + contentType = strings.TrimSpace(headers.Get("content-type")) + if contentType == "" { + contentType = strings.TrimSpace(headers.Get("Content-Type")) + } + } + 
rayID = soraerror.ExtractCloudflareRayID(headers, body) + return rayID, mitigated, contentType +} + +func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool { + return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body) +} + +func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool { + message = strings.TrimSpace(message) + if message == "" { + return false + } + if statusCode == http.StatusForbidden || statusCode == http.StatusTooManyRequests { + lower := strings.ToLower(message) + if strings.Contains(lower, "Just a moment...`) + + h := &SoraGatewayHandler{} + h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true) + + lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n") + require.Len(t, lines, 2) + jsonStr := strings.TrimPrefix(lines[1], "data: ") + + var parsed map[string]any + require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed)) + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "upstream_error", errorObj["type"]) + msg, _ := errorObj["message"].(string) + require.Contains(t, msg, "Cloudflare challenge") + require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA") +} + +func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + headers := http.Header{} + headers.Set("cf-ray", "9d03b68c086027a1-SEA") + body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`) + + h := &SoraGatewayHandler{} + h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true) + + lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n") + require.Len(t, lines, 2) + jsonStr := strings.TrimPrefix(lines[1], "data: ") + + var parsed map[string]any + require.NoError(t, 
json.Unmarshal([]byte(jsonStr), &parsed)) + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "rate_limit_error", errorObj["type"]) + msg, _ := errorObj["message"].(string) + require.Contains(t, msg, "Cloudflare shield") + require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA") +} + +func TestExtractSoraFailoverHeaderInsights(t *testing.T) { + headers := http.Header{} + headers.Set("cf-mitigated", "challenge") + headers.Set("content-type", "text/html") + body := []byte(``) + + rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(headers, body) + require.Equal(t, "9cff2d62d83bb98d", rayID) + require.Equal(t, "challenge", mitigated) + require.Equal(t, "text/html", contentType) +} diff --git a/backend/internal/handler/usage_handler.go b/backend/internal/handler/usage_handler.go index 129dbfa6..2bd0e0d7 100644 --- a/backend/internal/handler/usage_handler.go +++ b/backend/internal/handler/usage_handler.go @@ -2,6 +2,7 @@ package handler import ( "strconv" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/handler/dto" @@ -65,8 +66,17 @@ func (h *UsageHandler) List(c *gin.Context) { // Parse additional filters model := c.Query("model") + var requestType *int16 var stream *bool - if streamStr := c.Query("stream"); streamStr != "" { + if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { + parsed, err := service.ParseUsageRequestType(requestTypeStr) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + } else if streamStr := c.Query("stream"); streamStr != "" { val, err := strconv.ParseBool(streamStr) if err != nil { response.BadRequest(c, "Invalid stream value, use true or false") @@ -114,6 +124,7 @@ func (h *UsageHandler) List(c *gin.Context) { UserID: subject.UserID, // Always filter by current user for security APIKeyID: apiKeyID, Model: model, + RequestType: requestType, Stream: stream, BillingType: 
billingType, StartTime: startTime, @@ -392,7 +403,7 @@ func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) { return } - stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs) + stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs, time.Time{}, time.Time{}) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/usage_handler_request_type_test.go b/backend/internal/handler/usage_handler_request_type_test.go new file mode 100644 index 00000000..7c4c7913 --- /dev/null +++ b/backend/internal/handler/usage_handler_request_type_test.go @@ -0,0 +1,80 @@ +package handler + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type userUsageRepoCapture struct { + service.UsageLogRepository + listFilters usagestats.UsageLogFilters +} + +func (s *userUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) { + s.listFilters = filters + return []service.UsageLog{}, &pagination.PaginationResult{ + Total: 0, + Page: params.Page, + PageSize: params.PageSize, + Pages: 0, + }, nil +} + +func newUserUsageRequestTypeTestRouter(repo *userUsageRepoCapture) *gin.Engine { + gin.SetMode(gin.TestMode) + usageSvc := service.NewUsageService(repo, nil, nil, nil) + handler := NewUsageHandler(usageSvc, nil) + router := gin.New() + router.Use(func(c *gin.Context) { + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 42}) + c.Next() + }) + router.GET("/usage", handler.List) + return router +} + +func 
TestUserUsageListRequestTypePriority(t *testing.T) { + repo := &userUsageRepoCapture{} + router := newUserUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/usage?request_type=ws_v2&stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, int64(42), repo.listFilters.UserID) + require.NotNil(t, repo.listFilters.RequestType) + require.Equal(t, int16(service.RequestTypeWSV2), *repo.listFilters.RequestType) + require.Nil(t, repo.listFilters.Stream) +} + +func TestUserUsageListInvalidRequestType(t *testing.T) { + repo := &userUsageRepoCapture{} + router := newUserUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/usage?request_type=invalid", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestUserUsageListInvalidStream(t *testing.T) { + repo := &userUsageRepoCapture{} + router := newUserUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/usage?stream=invalid", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} diff --git a/backend/internal/handler/usage_record_submit_task_test.go b/backend/internal/handler/usage_record_submit_task_test.go new file mode 100644 index 00000000..c7c48e14 --- /dev/null +++ b/backend/internal/handler/usage_record_submit_task_test.go @@ -0,0 +1,184 @@ +package handler + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func newUsageRecordTestPool(t *testing.T) *service.UsageRecordWorkerPool { + t.Helper() + pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{ + WorkerCount: 1, + QueueSize: 8, + TaskTimeout: time.Second, + OverflowPolicy: "drop", + OverflowSamplePercent: 0, + 
AutoScaleEnabled: false, + }) + t.Cleanup(pool.Stop) + return pool +} + +func TestGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) { + pool := newUsageRecordTestPool(t) + h := &GatewayHandler{usageRecordWorkerPool: pool} + + done := make(chan struct{}) + h.submitUsageRecordTask(func(ctx context.Context) { + close(done) + }) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("task not executed") + } +} + +func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) { + h := &GatewayHandler{} + var called atomic.Bool + + h.submitUsageRecordTask(func(ctx context.Context) { + if _, ok := ctx.Deadline(); !ok { + t.Fatal("expected deadline in fallback context") + } + called.Store(true) + }) + + require.True(t, called.Load()) +} + +func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) { + h := &GatewayHandler{} + require.NotPanics(t, func() { + h.submitUsageRecordTask(nil) + }) +} + +func TestGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) { + h := &GatewayHandler{} + var called atomic.Bool + + require.NotPanics(t, func() { + h.submitUsageRecordTask(func(ctx context.Context) { + panic("usage task panic") + }) + }) + + h.submitUsageRecordTask(func(ctx context.Context) { + called.Store(true) + }) + require.True(t, called.Load(), "panic 后后续任务应仍可执行") +} + +func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) { + pool := newUsageRecordTestPool(t) + h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool} + + done := make(chan struct{}) + h.submitUsageRecordTask(func(ctx context.Context) { + close(done) + }) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("task not executed") + } +} + +func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) { + h := &OpenAIGatewayHandler{} + var called atomic.Bool + + h.submitUsageRecordTask(func(ctx context.Context) { + if _, ok := ctx.Deadline(); !ok { + 
t.Fatal("expected deadline in fallback context") + } + called.Store(true) + }) + + require.True(t, called.Load()) +} + +func TestOpenAIGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) { + h := &OpenAIGatewayHandler{} + require.NotPanics(t, func() { + h.submitUsageRecordTask(nil) + }) +} + +func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) { + h := &OpenAIGatewayHandler{} + var called atomic.Bool + + require.NotPanics(t, func() { + h.submitUsageRecordTask(func(ctx context.Context) { + panic("usage task panic") + }) + }) + + h.submitUsageRecordTask(func(ctx context.Context) { + called.Store(true) + }) + require.True(t, called.Load(), "panic 后后续任务应仍可执行") +} + +func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) { + pool := newUsageRecordTestPool(t) + h := &SoraGatewayHandler{usageRecordWorkerPool: pool} + + done := make(chan struct{}) + h.submitUsageRecordTask(func(ctx context.Context) { + close(done) + }) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("task not executed") + } +} + +func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) { + h := &SoraGatewayHandler{} + var called atomic.Bool + + h.submitUsageRecordTask(func(ctx context.Context) { + if _, ok := ctx.Deadline(); !ok { + t.Fatal("expected deadline in fallback context") + } + called.Store(true) + }) + + require.True(t, called.Load()) +} + +func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) { + h := &SoraGatewayHandler{} + require.NotPanics(t, func() { + h.submitUsageRecordTask(nil) + }) +} + +func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) { + h := &SoraGatewayHandler{} + var called atomic.Bool + + require.NotPanics(t, func() { + h.submitUsageRecordTask(func(ctx context.Context) { + panic("usage task panic") + }) + }) + + h.submitUsageRecordTask(func(ctx context.Context) { + called.Store(true) + }) + 
require.True(t, called.Load(), "panic 后后续任务应仍可执行") +} diff --git a/backend/internal/handler/user_msg_queue_helper.go b/backend/internal/handler/user_msg_queue_helper.go new file mode 100644 index 00000000..50449b13 --- /dev/null +++ b/backend/internal/handler/user_msg_queue_helper.go @@ -0,0 +1,237 @@ +package handler + +import ( + "context" + "fmt" + "net/http" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// UserMsgQueueHelper 用户消息串行队列 Handler 层辅助 +// 复用 ConcurrencyHelper 的退避 + SSE ping 模式 +type UserMsgQueueHelper struct { + queueService *service.UserMessageQueueService + pingFormat SSEPingFormat + pingInterval time.Duration +} + +// NewUserMsgQueueHelper 创建用户消息串行队列辅助 +func NewUserMsgQueueHelper( + queueService *service.UserMessageQueueService, + pingFormat SSEPingFormat, + pingInterval time.Duration, +) *UserMsgQueueHelper { + if pingInterval <= 0 { + pingInterval = defaultPingInterval + } + return &UserMsgQueueHelper{ + queueService: queueService, + pingFormat: pingFormat, + pingInterval: pingInterval, + } +} + +// AcquireWithWait 等待获取串行锁,流式请求期间发送 SSE ping +// 返回的 releaseFunc 内部使用 sync.Once,确保只执行一次释放 +func (h *UserMsgQueueHelper) AcquireWithWait( + c *gin.Context, + accountID int64, + baseRPM int, + isStream bool, + streamStarted *bool, + timeout time.Duration, + reqLog *zap.Logger, +) (releaseFunc func(), err error) { + ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) + defer cancel() + + // 先尝试立即获取 + result, err := h.queueService.TryAcquire(ctx, accountID) + if err != nil { + return nil, err // fail-open 已在 service 层处理 + } + + if result.Acquired { + // 获取成功,执行 RPM 自适应延迟 + if err := h.queueService.EnforceDelay(ctx, accountID, baseRPM); err != nil { + if ctx.Err() != nil { + // 延迟期间 context 取消,释放锁 + bgCtx, bgCancel := context.WithTimeout(context.Background(), 5*time.Second) + _ = h.queueService.Release(bgCtx, accountID, result.RequestID) + bgCancel() + return nil, 
ctx.Err() + } + } + reqLog.Debug("gateway.umq_lock_acquired", zap.Int64("account_id", accountID)) + return h.makeReleaseFunc(accountID, result.RequestID, reqLog), nil + } + + // 需要等待:指数退避轮询 + return h.waitForLockWithPing(c, ctx, accountID, baseRPM, isStream, streamStarted, reqLog) +} + +// waitForLockWithPing 等待获取锁,流式请求期间发送 SSE ping +func (h *UserMsgQueueHelper) waitForLockWithPing( + c *gin.Context, + ctx context.Context, + accountID int64, + baseRPM int, + isStream bool, + streamStarted *bool, + reqLog *zap.Logger, +) (func(), error) { + needPing := isStream && h.pingFormat != "" + + var flusher http.Flusher + if needPing { + var ok bool + flusher, ok = c.Writer.(http.Flusher) + if !ok { + needPing = false + } + } + + var pingCh <-chan time.Time + if needPing { + pingTicker := time.NewTicker(h.pingInterval) + defer pingTicker.Stop() + pingCh = pingTicker.C + } + + backoff := initialBackoff + timer := time.NewTimer(backoff) + defer timer.Stop() + + for { + select { + case <-ctx.Done(): + return nil, fmt.Errorf("umq wait timeout for account %d", accountID) + + case <-pingCh: + if !*streamStarted { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + *streamStarted = true + } + if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil { + return nil, err + } + flusher.Flush() + + case <-timer.C: + result, err := h.queueService.TryAcquire(ctx, accountID) + if err != nil { + return nil, err + } + if result.Acquired { + // 获取成功,执行 RPM 自适应延迟 + if delayErr := h.queueService.EnforceDelay(ctx, accountID, baseRPM); delayErr != nil { + if ctx.Err() != nil { + bgCtx, bgCancel := context.WithTimeout(context.Background(), 5*time.Second) + _ = h.queueService.Release(bgCtx, accountID, result.RequestID) + bgCancel() + return nil, ctx.Err() + } + } + reqLog.Debug("gateway.umq_lock_acquired", zap.Int64("account_id", accountID)) + return 
h.makeReleaseFunc(accountID, result.RequestID, reqLog), nil + } + backoff = nextBackoff(backoff) + timer.Reset(backoff) + } + } +} + +// makeReleaseFunc 创建锁释放函数(使用 sync.Once 确保只执行一次) +func (h *UserMsgQueueHelper) makeReleaseFunc(accountID int64, requestID string, reqLog *zap.Logger) func() { + var once sync.Once + return func() { + once.Do(func() { + bgCtx, bgCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer bgCancel() + if err := h.queueService.Release(bgCtx, accountID, requestID); err != nil { + reqLog.Warn("gateway.umq_release_failed", + zap.Int64("account_id", accountID), + zap.Error(err), + ) + } else { + reqLog.Debug("gateway.umq_lock_released", zap.Int64("account_id", accountID)) + } + }) + } +} + +// ThrottleWithPing 软性限速模式:施加 RPM 自适应延迟,流式期间发送 SSE ping +// 不获取串行锁,不阻塞并发。返回后即可转发请求。 +func (h *UserMsgQueueHelper) ThrottleWithPing( + c *gin.Context, + accountID int64, + baseRPM int, + isStream bool, + streamStarted *bool, + timeout time.Duration, + reqLog *zap.Logger, +) error { + ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) + defer cancel() + + delay := h.queueService.CalculateRPMAwareDelay(ctx, accountID, baseRPM) + if delay <= 0 { + return nil + } + + reqLog.Debug("gateway.umq_throttle_delay", + zap.Int64("account_id", accountID), + zap.Duration("delay", delay), + ) + + // 延迟期间发送 SSE ping(复用 waitForLockWithPing 的 ping 逻辑) + needPing := isStream && h.pingFormat != "" + var flusher http.Flusher + if needPing { + flusher, _ = c.Writer.(http.Flusher) + if flusher == nil { + needPing = false + } + } + + var pingCh <-chan time.Time + if needPing { + pingTicker := time.NewTicker(h.pingInterval) + defer pingTicker.Stop() + pingCh = pingTicker.C + } + + timer := time.NewTimer(delay) + defer timer.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-pingCh: + // SSE ping 逻辑(与 waitForLockWithPing 一致) + if !*streamStarted { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", 
"no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + *streamStarted = true + } + if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil { + return err + } + flusher.Flush() + case <-timer.C: + return nil + } + } +} diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 7b62149c..76f5a979 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -14,6 +14,7 @@ func ProvideAdminHandlers( groupHandler *admin.GroupHandler, accountHandler *admin.AccountHandler, announcementHandler *admin.AnnouncementHandler, + dataManagementHandler *admin.DataManagementHandler, oauthHandler *admin.OAuthHandler, openaiOAuthHandler *admin.OpenAIOAuthHandler, geminiOAuthHandler *admin.GeminiOAuthHandler, @@ -28,6 +29,7 @@ func ProvideAdminHandlers( usageHandler *admin.UsageHandler, userAttributeHandler *admin.UserAttributeHandler, errorPassthroughHandler *admin.ErrorPassthroughHandler, + apiKeyHandler *admin.AdminAPIKeyHandler, ) *AdminHandlers { return &AdminHandlers{ Dashboard: dashboardHandler, @@ -35,6 +37,7 @@ func ProvideAdminHandlers( Group: groupHandler, Account: accountHandler, Announcement: announcementHandler, + DataManagement: dataManagementHandler, OAuth: oauthHandler, OpenAIOAuth: openaiOAuthHandler, GeminiOAuth: geminiOAuthHandler, @@ -49,12 +52,13 @@ func ProvideAdminHandlers( Usage: usageHandler, UserAttribute: userAttributeHandler, ErrorPassthrough: errorPassthroughHandler, + APIKey: apiKeyHandler, } } // ProvideSystemHandler creates admin.SystemHandler with UpdateService -func ProvideSystemHandler(updateService *service.UpdateService) *admin.SystemHandler { - return admin.NewSystemHandler(updateService) +func ProvideSystemHandler(updateService *service.UpdateService, lockService *service.SystemOperationLockService) *admin.SystemHandler { + return admin.NewSystemHandler(updateService, lockService) } // ProvideSettingHandler creates SettingHandler with version 
from BuildInfo @@ -74,8 +78,12 @@ func ProvideHandlers( adminHandlers *AdminHandlers, gatewayHandler *GatewayHandler, openaiGatewayHandler *OpenAIGatewayHandler, + soraGatewayHandler *SoraGatewayHandler, + soraClientHandler *SoraClientHandler, settingHandler *SettingHandler, totpHandler *TotpHandler, + _ *service.IdempotencyCoordinator, + _ *service.IdempotencyCleanupService, ) *Handlers { return &Handlers{ Auth: authHandler, @@ -88,6 +96,8 @@ func ProvideHandlers( Admin: adminHandlers, Gateway: gatewayHandler, OpenAIGateway: openaiGatewayHandler, + SoraGateway: soraGatewayHandler, + SoraClient: soraClientHandler, Setting: settingHandler, Totp: totpHandler, } @@ -105,6 +115,7 @@ var ProviderSet = wire.NewSet( NewAnnouncementHandler, NewGatewayHandler, NewOpenAIGatewayHandler, + NewSoraGatewayHandler, NewTotpHandler, ProvideSettingHandler, @@ -114,6 +125,7 @@ var ProviderSet = wire.NewSet( admin.NewGroupHandler, admin.NewAccountHandler, admin.NewAnnouncementHandler, + admin.NewDataManagementHandler, admin.NewOAuthHandler, admin.NewOpenAIOAuthHandler, admin.NewGeminiOAuthHandler, @@ -128,6 +140,7 @@ var ProviderSet = wire.NewSet( admin.NewUsageHandler, admin.NewUserAttributeHandler, admin.NewErrorPassthroughHandler, + admin.NewAdminAPIKeyHandler, // AdminHandlers and Handlers constructors ProvideAdminHandlers, diff --git a/backend/internal/integration/e2e_gateway_test.go b/backend/internal/integration/e2e_gateway_test.go index ec0b29f7..8ee3f22e 100644 --- a/backend/internal/integration/e2e_gateway_test.go +++ b/backend/internal/integration/e2e_gateway_test.go @@ -21,11 +21,18 @@ var ( // - "" (默认): 使用 /v1/messages, /v1beta/models(混合模式,可调度 antigravity 账户) // - "/antigravity": 使用 /antigravity/v1/messages, /antigravity/v1beta/models(非混合模式,仅 antigravity 账户) endpointPrefix = getEnv("ENDPOINT_PREFIX", "") - claudeAPIKey = "sk-8e572bc3b3de92ace4f41f4256c28600ca11805732a7b693b5c44741346bbbb3" - geminiAPIKey = 
"sk-5950197a2085b38bbe5a1b229cc02b8ece914963fc44cacc06d497ae8b87410f" testInterval = 1 * time.Second // 测试间隔,防止限流 ) +const ( + // 注意:E2E 测试请使用环境变量注入密钥,避免任何凭证进入仓库历史。 + // 例如: + // export CLAUDE_API_KEY="sk-..." + // export GEMINI_API_KEY="sk-..." + claudeAPIKeyEnv = "CLAUDE_API_KEY" + geminiAPIKeyEnv = "GEMINI_API_KEY" +) + func getEnv(key, defaultVal string) string { if v := os.Getenv(key); v != "" { return v @@ -65,16 +72,45 @@ func TestMain(m *testing.M) { if endpointPrefix != "" { mode = "Antigravity 模式" } - fmt.Printf("\n🚀 E2E Gateway Tests - %s (prefix=%q, %s)\n\n", baseURL, endpointPrefix, mode) + claudeKeySet := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv)) != "" + geminiKeySet := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv)) != "" + fmt.Printf("\n🚀 E2E Gateway Tests - %s (prefix=%q, %s, %s=%v, %s=%v)\n\n", + baseURL, + endpointPrefix, + mode, + claudeAPIKeyEnv, + claudeKeySet, + geminiAPIKeyEnv, + geminiKeySet, + ) os.Exit(m.Run()) } +func requireClaudeAPIKey(t *testing.T) string { + t.Helper() + key := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv)) + if key == "" { + t.Skipf("未设置 %s,跳过 Claude 相关 E2E 测试", claudeAPIKeyEnv) + } + return key +} + +func requireGeminiAPIKey(t *testing.T) string { + t.Helper() + key := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv)) + if key == "" { + t.Skipf("未设置 %s,跳过 Gemini 相关 E2E 测试", geminiAPIKeyEnv) + } + return key +} + // TestClaudeModelsList 测试 GET /v1/models func TestClaudeModelsList(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) url := baseURL + endpointPrefix + "/v1/models" req, _ := http.NewRequest("GET", url, nil) - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) @@ -106,10 +142,11 @@ func TestClaudeModelsList(t *testing.T) { // TestGeminiModelsList 测试 GET /v1beta/models func TestGeminiModelsList(t *testing.T) { + geminiKey := requireGeminiAPIKey(t) url := baseURL + 
endpointPrefix + "/v1beta/models" req, _ := http.NewRequest("GET", url, nil) - req.Header.Set("Authorization", "Bearer "+geminiAPIKey) + req.Header.Set("Authorization", "Bearer "+geminiKey) client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) @@ -137,21 +174,22 @@ func TestGeminiModelsList(t *testing.T) { // TestClaudeMessages 测试 Claude /v1/messages 接口 func TestClaudeMessages(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) for i, model := range claudeModels { if i > 0 { time.Sleep(testInterval) } t.Run(model+"_非流式", func(t *testing.T) { - testClaudeMessage(t, model, false) + testClaudeMessage(t, claudeKey, model, false) }) time.Sleep(testInterval) t.Run(model+"_流式", func(t *testing.T) { - testClaudeMessage(t, model, true) + testClaudeMessage(t, claudeKey, model, true) }) } } -func testClaudeMessage(t *testing.T, model string, stream bool) { +func testClaudeMessage(t *testing.T, claudeKey string, model string, stream bool) { url := baseURL + endpointPrefix + "/v1/messages" payload := map[string]any{ @@ -166,7 +204,7 @@ func testClaudeMessage(t *testing.T, model string, stream bool) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) req.Header.Set("anthropic-version", "2023-06-01") client := &http.Client{Timeout: 60 * time.Second} @@ -213,21 +251,22 @@ func testClaudeMessage(t *testing.T, model string, stream bool) { // TestGeminiGenerateContent 测试 Gemini /v1beta/models/:model 接口 func TestGeminiGenerateContent(t *testing.T) { + geminiKey := requireGeminiAPIKey(t) for i, model := range geminiModels { if i > 0 { time.Sleep(testInterval) } t.Run(model+"_非流式", func(t *testing.T) { - testGeminiGenerate(t, model, false) + testGeminiGenerate(t, geminiKey, model, false) }) time.Sleep(testInterval) t.Run(model+"_流式", func(t *testing.T) { - testGeminiGenerate(t, model, 
true) + testGeminiGenerate(t, geminiKey, model, true) }) } } -func testGeminiGenerate(t *testing.T, model string, stream bool) { +func testGeminiGenerate(t *testing.T, geminiKey string, model string, stream bool) { action := "generateContent" if stream { action = "streamGenerateContent" @@ -254,7 +293,7 @@ func testGeminiGenerate(t *testing.T, model string, stream bool) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+geminiAPIKey) + req.Header.Set("Authorization", "Bearer "+geminiKey) client := &http.Client{Timeout: 60 * time.Second} resp, err := client.Do(req) @@ -301,6 +340,7 @@ func testGeminiGenerate(t *testing.T, model string, stream bool) { // TestClaudeMessagesWithComplexTools 测试带复杂工具 schema 的请求 // 模拟 Claude Code 发送的请求,包含需要清理的 JSON Schema 字段 func TestClaudeMessagesWithComplexTools(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) // 测试模型列表(只测试几个代表性模型) models := []string{ "claude-opus-4-5-20251101", // Claude 模型 @@ -312,12 +352,12 @@ func TestClaudeMessagesWithComplexTools(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_复杂工具", func(t *testing.T) { - testClaudeMessageWithTools(t, model) + testClaudeMessageWithTools(t, claudeKey, model) }) } } -func testClaudeMessageWithTools(t *testing.T, model string) { +func testClaudeMessageWithTools(t *testing.T, claudeKey string, model string) { url := baseURL + endpointPrefix + "/v1/messages" // 构造包含复杂 schema 的工具定义(模拟 Claude Code 的工具) @@ -473,7 +513,7 @@ func testClaudeMessageWithTools(t *testing.T, model string) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) req.Header.Set("anthropic-version", "2023-06-01") client := &http.Client{Timeout: 60 * time.Second} @@ -519,6 +559,7 @@ func testClaudeMessageWithTools(t *testing.T, 
model string) { // 验证:当历史 assistant 消息包含 tool_use 但没有 signature 时, // 系统应自动添加 dummy thought_signature 避免 Gemini 400 错误 func TestClaudeMessagesWithThinkingAndTools(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) models := []string{ "claude-haiku-4-5-20251001", // gemini-3-flash } @@ -527,12 +568,12 @@ func TestClaudeMessagesWithThinkingAndTools(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_thinking模式工具调用", func(t *testing.T) { - testClaudeThinkingWithToolHistory(t, model) + testClaudeThinkingWithToolHistory(t, claudeKey, model) }) } } -func testClaudeThinkingWithToolHistory(t *testing.T, model string) { +func testClaudeThinkingWithToolHistory(t *testing.T, claudeKey string, model string) { url := baseURL + endpointPrefix + "/v1/messages" // 模拟历史对话:用户请求 → assistant 调用工具 → 工具返回 → 继续对话 @@ -600,7 +641,7 @@ func testClaudeThinkingWithToolHistory(t *testing.T, model string) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) req.Header.Set("anthropic-version", "2023-06-01") client := &http.Client{Timeout: 60 * time.Second} @@ -649,6 +690,7 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) { if endpointPrefix != "/antigravity" { t.Skip("仅在 Antigravity 模式下运行") } + claudeKey := requireClaudeAPIKey(t) // 测试通过 Claude 端点调用 Gemini 模型 geminiViaClaude := []string{ @@ -664,11 +706,11 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_通过Claude端点", func(t *testing.T) { - testClaudeMessage(t, model, false) + testClaudeMessage(t, claudeKey, model, false) }) time.Sleep(testInterval) t.Run(model+"_通过Claude端点_流式", func(t *testing.T) { - testClaudeMessage(t, model, true) + testClaudeMessage(t, claudeKey, model, true) }) } } @@ -676,6 +718,7 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) { // TestClaudeMessagesWithNoSignature 测试历史 thinking 
block 不带 signature 的场景 // 验证:Gemini 模型接受没有 signature 的 thinking block func TestClaudeMessagesWithNoSignature(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) models := []string{ "claude-haiku-4-5-20251001", // gemini-3-flash - 支持无 signature } @@ -684,12 +727,12 @@ func TestClaudeMessagesWithNoSignature(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_无signature", func(t *testing.T) { - testClaudeWithNoSignature(t, model) + testClaudeWithNoSignature(t, claudeKey, model) }) } } -func testClaudeWithNoSignature(t *testing.T, model string) { +func testClaudeWithNoSignature(t *testing.T, claudeKey string, model string) { url := baseURL + endpointPrefix + "/v1/messages" // 模拟历史对话包含 thinking block 但没有 signature @@ -732,7 +775,7 @@ func testClaudeWithNoSignature(t *testing.T, model string) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) req.Header.Set("anthropic-version", "2023-06-01") client := &http.Client{Timeout: 60 * time.Second} @@ -777,6 +820,7 @@ func TestGeminiEndpointWithClaudeModel(t *testing.T) { if endpointPrefix != "/antigravity" { t.Skip("仅在 Antigravity 模式下运行") } + geminiKey := requireGeminiAPIKey(t) // 测试通过 Gemini 端点调用 Claude 模型 claudeViaGemini := []string{ @@ -789,11 +833,11 @@ func TestGeminiEndpointWithClaudeModel(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_通过Gemini端点", func(t *testing.T) { - testGeminiGenerate(t, model, false) + testGeminiGenerate(t, geminiKey, model, false) }) time.Sleep(testInterval) t.Run(model+"_通过Gemini端点_流式", func(t *testing.T) { - testGeminiGenerate(t, model, true) + testGeminiGenerate(t, geminiKey, model, true) }) } } diff --git a/backend/internal/integration/e2e_helpers_test.go b/backend/internal/integration/e2e_helpers_test.go new file mode 100644 index 00000000..7d266bcb --- /dev/null +++ 
b/backend/internal/integration/e2e_helpers_test.go @@ -0,0 +1,48 @@ +//go:build e2e + +package integration + +import ( + "os" + "strings" + "testing" +) + +// ============================================================================= +// E2E Mock 模式支持 +// ============================================================================= +// 当 E2E_MOCK=true 时,使用本地 Mock 响应替代真实 API 调用。 +// 这允许在没有真实 API Key 的环境(如 CI)中验证基本的请求/响应流程。 + +// isMockMode 检查是否启用 Mock 模式 +func isMockMode() bool { + return strings.EqualFold(os.Getenv("E2E_MOCK"), "true") +} + +// skipIfNoRealAPI 如果未配置真实 API Key 且不在 Mock 模式,则跳过测试 +func skipIfNoRealAPI(t *testing.T) { + t.Helper() + if isMockMode() { + return // Mock 模式下不跳过 + } + claudeKey := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv)) + geminiKey := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv)) + if claudeKey == "" && geminiKey == "" { + t.Skip("未设置 API Key 且未启用 Mock 模式,跳过测试") + } +} + +// ============================================================================= +// API Key 脱敏(Task 6.10) +// ============================================================================= + +// safeLogKey 安全地记录 API Key(仅显示前 8 位) +func safeLogKey(t *testing.T, prefix string, key string) { + t.Helper() + key = strings.TrimSpace(key) + if len(key) <= 8 { + t.Logf("%s: ***(长度: %d)", prefix, len(key)) + return + } + t.Logf("%s: %s...(长度: %d)", prefix, key[:8], len(key)) +} diff --git a/backend/internal/integration/e2e_user_flow_test.go b/backend/internal/integration/e2e_user_flow_test.go new file mode 100644 index 00000000..5489d0a3 --- /dev/null +++ b/backend/internal/integration/e2e_user_flow_test.go @@ -0,0 +1,317 @@ +//go:build e2e + +package integration + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "testing" + "time" +) + +// E2E 用户流程测试 +// 测试完整的用户操作链路:注册 → 登录 → 创建 API Key → 调用网关 → 查询用量 + +var ( + testUserEmail = "e2e-test-" + fmt.Sprintf("%d", time.Now().UnixMilli()) + "@test.local" + testUserPassword = "E2eTest@12345" + 
testUserName = "e2e-test-user" +) + +// TestUserRegistrationAndLogin 测试用户注册和登录流程 +func TestUserRegistrationAndLogin(t *testing.T) { + // 步骤 1: 注册新用户 + t.Run("注册新用户", func(t *testing.T) { + payload := map[string]string{ + "email": testUserEmail, + "password": testUserPassword, + "username": testUserName, + } + body, _ := json.Marshal(payload) + + resp, err := doRequest(t, "POST", "/api/auth/register", body, "") + if err != nil { + t.Skipf("注册接口不可用,跳过用户流程测试: %v", err) + return + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + // 注册可能返回 200(成功)或 400(邮箱已存在)或 403(注册已关闭) + switch resp.StatusCode { + case 200: + t.Logf("✅ 用户注册成功: %s", testUserEmail) + case 400: + t.Logf("⚠️ 用户可能已存在: %s", string(respBody)) + case 403: + t.Skipf("注册功能已关闭: %s", string(respBody)) + default: + t.Logf("⚠️ 注册返回 HTTP %d: %s(继续尝试登录)", resp.StatusCode, string(respBody)) + } + }) + + // 步骤 2: 登录获取 JWT + var accessToken string + t.Run("用户登录获取JWT", func(t *testing.T) { + payload := map[string]string{ + "email": testUserEmail, + "password": testUserPassword, + } + body, _ := json.Marshal(payload) + + resp, err := doRequest(t, "POST", "/api/auth/login", body, "") + if err != nil { + t.Fatalf("登录请求失败: %v", err) + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != 200 { + t.Skipf("登录失败 HTTP %d: %s(可能需要先注册用户)", resp.StatusCode, string(respBody)) + return + } + + var result map[string]any + if err := json.Unmarshal(respBody, &result); err != nil { + t.Fatalf("解析登录响应失败: %v", err) + } + + // 尝试从标准响应格式获取 token + if token, ok := result["access_token"].(string); ok && token != "" { + accessToken = token + } else if data, ok := result["data"].(map[string]any); ok { + if token, ok := data["access_token"].(string); ok { + accessToken = token + } + } + + if accessToken == "" { + t.Skipf("未获取到 access_token,响应: %s", string(respBody)) + return + } + + // 验证 token 不为空且格式基本正确 + if len(accessToken) < 10 { + t.Fatalf("access_token 格式异常: %s", accessToken) 
+ } + + t.Logf("✅ 登录成功,获取 JWT(长度: %d)", len(accessToken)) + }) + + if accessToken == "" { + t.Skip("未获取到 JWT,跳过后续测试") + return + } + + // 步骤 3: 使用 JWT 获取当前用户信息 + t.Run("获取当前用户信息", func(t *testing.T) { + resp, err := doRequest(t, "GET", "/api/user/me", nil, accessToken) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("HTTP %d: %s", resp.StatusCode, string(body)) + } + + t.Logf("✅ 成功获取用户信息") + }) +} + +// TestAPIKeyLifecycle 测试 API Key 的创建和使用 +func TestAPIKeyLifecycle(t *testing.T) { + // 先登录获取 JWT + accessToken := loginTestUser(t) + if accessToken == "" { + t.Skip("无法登录,跳过 API Key 生命周期测试") + return + } + + var apiKey string + + // 步骤 1: 创建 API Key + t.Run("创建API_Key", func(t *testing.T) { + payload := map[string]string{ + "name": "e2e-test-key-" + fmt.Sprintf("%d", time.Now().UnixMilli()), + } + body, _ := json.Marshal(payload) + + resp, err := doRequest(t, "POST", "/api/keys", body, accessToken) + if err != nil { + t.Fatalf("创建 API Key 请求失败: %v", err) + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != 200 { + t.Skipf("创建 API Key 失败 HTTP %d: %s", resp.StatusCode, string(respBody)) + return + } + + var result map[string]any + if err := json.Unmarshal(respBody, &result); err != nil { + t.Fatalf("解析响应失败: %v", err) + } + + // 从响应中提取 key + if key, ok := result["key"].(string); ok { + apiKey = key + } else if data, ok := result["data"].(map[string]any); ok { + if key, ok := data["key"].(string); ok { + apiKey = key + } + } + + if apiKey == "" { + t.Skipf("未获取到 API Key,响应: %s", string(respBody)) + return + } + + // 验证 API Key 脱敏日志(只显示前 8 位) + masked := apiKey + if len(masked) > 8 { + masked = masked[:8] + "..." 
+ } + t.Logf("✅ API Key 创建成功: %s", masked) + }) + + if apiKey == "" { + t.Skip("未创建 API Key,跳过后续测试") + return + } + + // 步骤 2: 使用 API Key 调用网关(需要 Claude 或 Gemini 可用) + t.Run("使用API_Key调用网关", func(t *testing.T) { + // 尝试调用 models 列表(最轻量的 API 调用) + resp, err := doRequest(t, "GET", "/v1/models", nil, apiKey) + if err != nil { + t.Fatalf("网关请求失败: %v", err) + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + // 可能返回 200(成功)或 402(余额不足)或 403(无可用账户) + switch { + case resp.StatusCode == 200: + t.Logf("✅ API Key 网关调用成功") + case resp.StatusCode == 402: + t.Logf("⚠️ 余额不足,但 API Key 认证通过") + case resp.StatusCode == 403: + t.Logf("⚠️ 无可用账户,但 API Key 认证通过") + default: + t.Logf("⚠️ 网关返回 HTTP %d: %s", resp.StatusCode, string(respBody)) + } + }) + + // 步骤 3: 查询用量记录 + t.Run("查询用量记录", func(t *testing.T) { + resp, err := doRequest(t, "GET", "/api/usage/dashboard", nil, accessToken) + if err != nil { + t.Fatalf("用量查询请求失败: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + body, _ := io.ReadAll(resp.Body) + t.Logf("⚠️ 用量查询返回 HTTP %d: %s", resp.StatusCode, string(body)) + return + } + + t.Logf("✅ 用量查询成功") + }) +} + +// ============================================================================= +// 辅助函数 +// ============================================================================= + +func doRequest(t *testing.T, method, path string, body []byte, token string) (*http.Response, error) { + t.Helper() + + url := baseURL + path + var bodyReader io.Reader + if body != nil { + bodyReader = bytes.NewReader(body) + } + + req, err := http.NewRequest(method, url, bodyReader) + if err != nil { + return nil, fmt.Errorf("创建请求失败: %w", err) + } + + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + if token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + + client := &http.Client{Timeout: 30 * time.Second} + return client.Do(req) +} + +func loginTestUser(t *testing.T) string { + t.Helper() + + // 先尝试用管理员账户登录 + 
adminEmail := getEnv("ADMIN_EMAIL", "admin@sub2api.local") + adminPassword := getEnv("ADMIN_PASSWORD", "") + + if adminPassword == "" { + // 尝试用测试用户 + adminEmail = testUserEmail + adminPassword = testUserPassword + } + + payload := map[string]string{ + "email": adminEmail, + "password": adminPassword, + } + body, _ := json.Marshal(payload) + + resp, err := doRequest(t, "POST", "/api/auth/login", body, "") + if err != nil { + return "" + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + return "" + } + + respBody, _ := io.ReadAll(resp.Body) + var result map[string]any + if err := json.Unmarshal(respBody, &result); err != nil { + return "" + } + + if token, ok := result["access_token"].(string); ok { + return token + } + if data, ok := result["data"].(map[string]any); ok { + if token, ok := data["access_token"].(string); ok { + return token + } + } + + return "" +} + +// redactAPIKey API Key 脱敏,只显示前 8 位 +func redactAPIKey(key string) string { + key = strings.TrimSpace(key) + if len(key) <= 8 { + return "***" + } + return key[:8] + "..." 
+} diff --git a/backend/internal/middleware/rate_limiter_test.go b/backend/internal/middleware/rate_limiter_test.go index 0c379c0f..e362274f 100644 --- a/backend/internal/middleware/rate_limiter_test.go +++ b/backend/internal/middleware/rate_limiter_test.go @@ -60,6 +60,49 @@ func TestRateLimiterFailureModes(t *testing.T) { require.Equal(t, http.StatusTooManyRequests, recorder.Code) } +func TestRateLimiterDifferentIPsIndependent(t *testing.T) { + gin.SetMode(gin.TestMode) + + callCounts := make(map[string]int64) + originalRun := rateLimitRun + rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) { + callCounts[key]++ + return callCounts[key], false, nil + } + t.Cleanup(func() { + rateLimitRun = originalRun + }) + + limiter := NewRateLimiter(redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"})) + + router := gin.New() + router.Use(limiter.Limit("api", 1, time.Second)) + router.GET("/test", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + // 第一个 IP 的请求应通过 + req1 := httptest.NewRequest(http.MethodGet, "/test", nil) + req1.RemoteAddr = "10.0.0.1:1234" + rec1 := httptest.NewRecorder() + router.ServeHTTP(rec1, req1) + require.Equal(t, http.StatusOK, rec1.Code, "第一个 IP 的第一次请求应通过") + + // 第二个 IP 的请求应独立通过(不受第一个 IP 的计数影响) + req2 := httptest.NewRequest(http.MethodGet, "/test", nil) + req2.RemoteAddr = "10.0.0.2:5678" + rec2 := httptest.NewRecorder() + router.ServeHTTP(rec2, req2) + require.Equal(t, http.StatusOK, rec2.Code, "第二个 IP 的第一次请求应独立通过") + + // 第一个 IP 的第二次请求应被限流 + req3 := httptest.NewRequest(http.MethodGet, "/test", nil) + req3.RemoteAddr = "10.0.0.1:1234" + rec3 := httptest.NewRecorder() + router.ServeHTTP(rec3, req3) + require.Equal(t, http.StatusTooManyRequests, rec3.Code, "第一个 IP 的第二次请求应被限流") +} + func TestRateLimiterSuccessAndLimit(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/pkg/antigravity/claude_types.go 
b/backend/internal/pkg/antigravity/claude_types.go index 7c127b90..7cc68060 100644 --- a/backend/internal/pkg/antigravity/claude_types.go +++ b/backend/internal/pkg/antigravity/claude_types.go @@ -151,6 +151,9 @@ var claudeModels = []modelDef{ {ID: "claude-opus-4-5-thinking", DisplayName: "Claude Opus 4.5 Thinking", CreatedAt: "2025-11-01T00:00:00Z"}, {ID: "claude-sonnet-4-5", DisplayName: "Claude Sonnet 4.5", CreatedAt: "2025-09-29T00:00:00Z"}, {ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"}, + {ID: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-05T00:00:00Z"}, + {ID: "claude-opus-4-6-thinking", DisplayName: "Claude Opus 4.6 Thinking", CreatedAt: "2026-02-05T00:00:00Z"}, + {ID: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", CreatedAt: "2026-02-17T00:00:00Z"}, } // Antigravity 支持的 Gemini 模型 @@ -161,6 +164,10 @@ var geminiModels = []modelDef{ {ID: "gemini-3-flash", DisplayName: "Gemini 3 Flash", CreatedAt: "2025-06-01T00:00:00Z"}, {ID: "gemini-3-pro-low", DisplayName: "Gemini 3 Pro Low", CreatedAt: "2025-06-01T00:00:00Z"}, {ID: "gemini-3-pro-high", DisplayName: "Gemini 3 Pro High", CreatedAt: "2025-06-01T00:00:00Z"}, + {ID: "gemini-3.1-pro-low", DisplayName: "Gemini 3.1 Pro Low", CreatedAt: "2026-02-19T00:00:00Z"}, + {ID: "gemini-3.1-pro-high", DisplayName: "Gemini 3.1 Pro High", CreatedAt: "2026-02-19T00:00:00Z"}, + {ID: "gemini-3.1-flash-image", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: "2026-02-19T00:00:00Z"}, + {ID: "gemini-3.1-flash-image-preview", DisplayName: "Gemini 3.1 Flash Image Preview", CreatedAt: "2026-02-19T00:00:00Z"}, {ID: "gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview", CreatedAt: "2025-06-01T00:00:00Z"}, {ID: "gemini-3-pro-image", DisplayName: "Gemini 3 Pro Image", CreatedAt: "2025-06-01T00:00:00Z"}, } diff --git a/backend/internal/pkg/antigravity/claude_types_test.go b/backend/internal/pkg/antigravity/claude_types_test.go new file 
mode 100644 index 00000000..f7cb0a24 --- /dev/null +++ b/backend/internal/pkg/antigravity/claude_types_test.go @@ -0,0 +1,26 @@ +package antigravity + +import "testing" + +func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) { + t.Parallel() + + models := DefaultModels() + byID := make(map[string]ClaudeModel, len(models)) + for _, m := range models { + byID[m.ID] = m + } + + requiredIDs := []string{ + "claude-opus-4-6-thinking", + "gemini-3.1-flash-image", + "gemini-3.1-flash-image-preview", + "gemini-3-pro-image", // legacy compatibility + } + + for _, id := range requiredIDs { + if _, ok := byID[id]; !ok { + t.Fatalf("expected model %q to be exposed in DefaultModels", id) + } + } +} diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index ac32fae5..d46bbc45 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -14,6 +14,9 @@ import ( "net/url" "strings" "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" ) // NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点) @@ -33,7 +36,7 @@ func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken stri // 基础 Headers(与 Antigravity-Manager 保持一致,只设置这 3 个) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("User-Agent", UserAgent) + req.Header.Set("User-Agent", GetUserAgent()) return req, nil } @@ -149,22 +152,26 @@ type Client struct { httpClient *http.Client } -func NewClient(proxyURL string) *Client { +func NewClient(proxyURL string) (*Client, error) { client := &http.Client{ Timeout: 30 * time.Second, } - if strings.TrimSpace(proxyURL) != "" { - if proxyURLParsed, err := url.Parse(proxyURL); err == nil { - client.Transport = &http.Transport{ - Proxy: http.ProxyURL(proxyURLParsed), - } + _, parsed, err := proxyurl.Parse(proxyURL) + if 
err != nil { + return nil, err + } + if parsed != nil { + transport := &http.Transport{} + if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil { + return nil, fmt.Errorf("configure proxy: %w", err) } + client.Transport = transport } return &Client{ httpClient: client, - } + }, nil } // isConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝) @@ -204,9 +211,14 @@ func shouldFallbackToNextURL(err error, statusCode int) bool { // ExchangeCode 用 authorization code 交换 token func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) { + clientSecret, err := getClientSecret() + if err != nil { + return nil, err + } + params := url.Values{} params.Set("client_id", ClientID) - params.Set("client_secret", ClientSecret) + params.Set("client_secret", clientSecret) params.Set("code", code) params.Set("redirect_uri", RedirectURI) params.Set("grant_type", "authorization_code") @@ -243,9 +255,14 @@ func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (* // RefreshToken 刷新 access_token func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) { + clientSecret, err := getClientSecret() + if err != nil { + return nil, err + } + params := url.Values{} params.Set("client_id", ClientID) - params.Set("client_secret", ClientSecret) + params.Set("client_secret", clientSecret) params.Set("refresh_token", refreshToken) params.Set("grant_type", "refresh_token") @@ -333,7 +350,7 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC } req.Header.Set("Authorization", "Bearer "+accessToken) req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", UserAgent) + req.Header.Set("User-Agent", GetUserAgent()) resp, err := c.httpClient.Do(req) if err != nil { @@ -412,7 +429,7 @@ func (c *Client) OnboardUser(ctx context.Context, accessToken, tierID string) (s } req.Header.Set("Authorization", "Bearer "+accessToken) 
req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", UserAgent) + req.Header.Set("User-Agent", GetUserAgent()) resp, err := c.httpClient.Do(req) if err != nil { @@ -532,7 +549,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI } req.Header.Set("Authorization", "Bearer "+accessToken) req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", UserAgent) + req.Header.Set("User-Agent", GetUserAgent()) resp, err := c.httpClient.Do(req) if err != nil { diff --git a/backend/internal/pkg/antigravity/client_test.go b/backend/internal/pkg/antigravity/client_test.go index ac30093d..20b57833 100644 --- a/backend/internal/pkg/antigravity/client_test.go +++ b/backend/internal/pkg/antigravity/client_test.go @@ -1,9 +1,1703 @@ +//go:build unit + package antigravity import ( + "context" + "encoding/json" + "fmt" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" "testing" + "time" ) +// --------------------------------------------------------------------------- +// NewAPIRequestWithURL +// --------------------------------------------------------------------------- + +func TestNewAPIRequestWithURL_普通请求(t *testing.T) { + ctx := context.Background() + baseURL := "https://example.com" + action := "generateContent" + token := "test-token" + body := []byte(`{"prompt":"hello"}`) + + req, err := NewAPIRequestWithURL(ctx, baseURL, action, token, body) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + + // 验证 URL 不含 ?alt=sse + expectedURL := "https://example.com/v1internal:generateContent" + if req.URL.String() != expectedURL { + t.Errorf("URL 不匹配: got %s, want %s", req.URL.String(), expectedURL) + } + + // 验证请求方法 + if req.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s, want POST", req.Method) + } + + // 验证 Headers + if ct := req.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("Content-Type 不匹配: got %s", ct) + } + if auth := req.Header.Get("Authorization"); 
auth != "Bearer test-token" { + t.Errorf("Authorization 不匹配: got %s", auth) + } + if ua := req.Header.Get("User-Agent"); ua != GetUserAgent() { + t.Errorf("User-Agent 不匹配: got %s, want %s", ua, GetUserAgent()) + } +} + +func TestNewAPIRequestWithURL_流式请求(t *testing.T) { + ctx := context.Background() + baseURL := "https://example.com" + action := "streamGenerateContent" + token := "tok" + body := []byte(`{}`) + + req, err := NewAPIRequestWithURL(ctx, baseURL, action, token, body) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + + expectedURL := "https://example.com/v1internal:streamGenerateContent?alt=sse" + if req.URL.String() != expectedURL { + t.Errorf("URL 不匹配: got %s, want %s", req.URL.String(), expectedURL) + } +} + +func TestNewAPIRequestWithURL_空Body(t *testing.T) { + ctx := context.Background() + req, err := NewAPIRequestWithURL(ctx, "https://example.com", "test", "tok", nil) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + if req.Body == nil { + t.Error("Body 应该非 nil(bytes.NewReader(nil) 会返回空 reader)") + } +} + +// --------------------------------------------------------------------------- +// NewAPIRequest +// --------------------------------------------------------------------------- + +func TestNewAPIRequest_使用默认URL(t *testing.T) { + ctx := context.Background() + req, err := NewAPIRequest(ctx, "generateContent", "tok", []byte(`{}`)) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + + expected := BaseURL + "/v1internal:generateContent" + if req.URL.String() != expected { + t.Errorf("URL 不匹配: got %s, want %s", req.URL.String(), expected) + } +} + +// --------------------------------------------------------------------------- +// TierInfo.UnmarshalJSON +// --------------------------------------------------------------------------- + +func TestTierInfo_UnmarshalJSON_字符串格式(t *testing.T) { + data := []byte(`"free-tier"`) + var tier TierInfo + if err := tier.UnmarshalJSON(data); err != nil { + t.Fatalf("反序列化失败: %v", err) + } + if tier.ID != 
"free-tier" { + t.Errorf("ID 不匹配: got %s, want free-tier", tier.ID) + } + if tier.Name != "" { + t.Errorf("Name 应为空: got %s", tier.Name) + } +} + +func TestTierInfo_UnmarshalJSON_对象格式(t *testing.T) { + data := []byte(`{"id":"g1-pro-tier","name":"Pro","description":"Pro plan"}`) + var tier TierInfo + if err := tier.UnmarshalJSON(data); err != nil { + t.Fatalf("反序列化失败: %v", err) + } + if tier.ID != "g1-pro-tier" { + t.Errorf("ID 不匹配: got %s, want g1-pro-tier", tier.ID) + } + if tier.Name != "Pro" { + t.Errorf("Name 不匹配: got %s, want Pro", tier.Name) + } + if tier.Description != "Pro plan" { + t.Errorf("Description 不匹配: got %s, want Pro plan", tier.Description) + } +} + +func TestTierInfo_UnmarshalJSON_null(t *testing.T) { + data := []byte(`null`) + var tier TierInfo + if err := tier.UnmarshalJSON(data); err != nil { + t.Fatalf("反序列化 null 失败: %v", err) + } + if tier.ID != "" { + t.Errorf("null 场景下 ID 应为空: got %s", tier.ID) + } +} + +func TestTierInfo_UnmarshalJSON_空数据(t *testing.T) { + data := []byte(``) + var tier TierInfo + if err := tier.UnmarshalJSON(data); err != nil { + t.Fatalf("反序列化空数据失败: %v", err) + } + if tier.ID != "" { + t.Errorf("空数据场景下 ID 应为空: got %s", tier.ID) + } +} + +func TestTierInfo_UnmarshalJSON_空格包裹null(t *testing.T) { + data := []byte(` null `) + var tier TierInfo + if err := tier.UnmarshalJSON(data); err != nil { + t.Fatalf("反序列化空格 null 失败: %v", err) + } + if tier.ID != "" { + t.Errorf("空格 null 场景下 ID 应为空: got %s", tier.ID) + } +} + +func TestTierInfo_UnmarshalJSON_通过JSON嵌套结构(t *testing.T) { + // 模拟 LoadCodeAssistResponse 中的嵌套反序列化 + jsonData := `{"currentTier":"free-tier","paidTier":{"id":"g1-ultra-tier","name":"Ultra"}}` + var resp LoadCodeAssistResponse + if err := json.Unmarshal([]byte(jsonData), &resp); err != nil { + t.Fatalf("反序列化嵌套结构失败: %v", err) + } + if resp.CurrentTier == nil || resp.CurrentTier.ID != "free-tier" { + t.Errorf("CurrentTier 不匹配: got %+v", resp.CurrentTier) + } + if resp.PaidTier == nil || resp.PaidTier.ID != 
"g1-ultra-tier" { + t.Errorf("PaidTier 不匹配: got %+v", resp.PaidTier) + } +} + +// --------------------------------------------------------------------------- +// LoadCodeAssistResponse.GetTier +// --------------------------------------------------------------------------- + +func TestGetTier_PaidTier优先(t *testing.T) { + resp := &LoadCodeAssistResponse{ + CurrentTier: &TierInfo{ID: "free-tier"}, + PaidTier: &TierInfo{ID: "g1-pro-tier"}, + } + if got := resp.GetTier(); got != "g1-pro-tier" { + t.Errorf("应返回 paidTier: got %s", got) + } +} + +func TestGetTier_回退到CurrentTier(t *testing.T) { + resp := &LoadCodeAssistResponse{ + CurrentTier: &TierInfo{ID: "free-tier"}, + } + if got := resp.GetTier(); got != "free-tier" { + t.Errorf("应返回 currentTier: got %s", got) + } +} + +func TestGetTier_PaidTier为空ID(t *testing.T) { + resp := &LoadCodeAssistResponse{ + CurrentTier: &TierInfo{ID: "free-tier"}, + PaidTier: &TierInfo{ID: ""}, + } + // paidTier.ID 为空时应回退到 currentTier + if got := resp.GetTier(); got != "free-tier" { + t.Errorf("paidTier.ID 为空时应回退到 currentTier: got %s", got) + } +} + +func TestGetTier_两者都为nil(t *testing.T) { + resp := &LoadCodeAssistResponse{} + if got := resp.GetTier(); got != "" { + t.Errorf("两者都为 nil 时应返回空字符串: got %s", got) + } +} + +// --------------------------------------------------------------------------- +// NewClient +// --------------------------------------------------------------------------- + +func mustNewClient(t *testing.T, proxyURL string) *Client { + t.Helper() + client, err := NewClient(proxyURL) + if err != nil { + t.Fatalf("NewClient(%q) failed: %v", proxyURL, err) + } + return client +} + +func TestNewClient_无代理(t *testing.T) { + client, err := NewClient("") + if err != nil { + t.Fatalf("NewClient 返回错误: %v", err) + } + if client == nil { + t.Fatal("NewClient 返回 nil") + } + if client.httpClient == nil { + t.Fatal("httpClient 为 nil") + } + if client.httpClient.Timeout != 30*time.Second { + t.Errorf("Timeout 不匹配: got %v, want 30s", 
client.httpClient.Timeout) + } + // 无代理时 Transport 应为 nil(使用默认) + if client.httpClient.Transport != nil { + t.Error("无代理时 Transport 应为 nil") + } +} + +func TestNewClient_有代理(t *testing.T) { + client, err := NewClient("http://proxy.example.com:8080") + if err != nil { + t.Fatalf("NewClient 返回错误: %v", err) + } + if client == nil { + t.Fatal("NewClient 返回 nil") + } + if client.httpClient.Transport == nil { + t.Fatal("有代理时 Transport 不应为 nil") + } +} + +func TestNewClient_空格代理(t *testing.T) { + client, err := NewClient(" ") + if err != nil { + t.Fatalf("NewClient 返回错误: %v", err) + } + if client == nil { + t.Fatal("NewClient 返回 nil") + } + // 空格代理应等同于无代理 + if client.httpClient.Transport != nil { + t.Error("空格代理 Transport 应为 nil") + } +} + +func TestNewClient_无效代理URL(t *testing.T) { + // 无效 URL 应返回 error + _, err := NewClient("://invalid") + if err == nil { + t.Fatal("无效代理 URL 应返回错误") + } + if !strings.Contains(err.Error(), "invalid proxy URL") { + t.Errorf("错误信息应包含 'invalid proxy URL': got %s", err.Error()) + } +} + +// --------------------------------------------------------------------------- +// isConnectionError +// --------------------------------------------------------------------------- + +func TestIsConnectionError_nil(t *testing.T) { + if isConnectionError(nil) { + t.Error("nil 错误不应判定为连接错误") + } +} + +func TestIsConnectionError_超时错误(t *testing.T) { + // 使用 net.OpError 包装超时 + err := &net.OpError{ + Op: "dial", + Net: "tcp", + Err: &timeoutError{}, + } + if !isConnectionError(err) { + t.Error("超时错误应判定为连接错误") + } +} + +// timeoutError 实现 net.Error 接口用于测试 +type timeoutError struct{} + +func (e *timeoutError) Error() string { return "timeout" } +func (e *timeoutError) Timeout() bool { return true } +func (e *timeoutError) Temporary() bool { return true } + +func TestIsConnectionError_netOpError(t *testing.T) { + err := &net.OpError{ + Op: "dial", + Net: "tcp", + Err: fmt.Errorf("connection refused"), + } + if !isConnectionError(err) { + t.Error("net.OpError 
应判定为连接错误") + } +} + +func TestIsConnectionError_urlError(t *testing.T) { + err := &url.Error{ + Op: "Get", + URL: "https://example.com", + Err: fmt.Errorf("some error"), + } + if !isConnectionError(err) { + t.Error("url.Error 应判定为连接错误") + } +} + +func TestIsConnectionError_普通错误(t *testing.T) { + err := fmt.Errorf("some random error") + if isConnectionError(err) { + t.Error("普通错误不应判定为连接错误") + } +} + +func TestIsConnectionError_包装的netOpError(t *testing.T) { + inner := &net.OpError{ + Op: "dial", + Net: "tcp", + Err: fmt.Errorf("connection refused"), + } + err := fmt.Errorf("wrapping: %w", inner) + if !isConnectionError(err) { + t.Error("被包装的 net.OpError 应判定为连接错误") + } +} + +// --------------------------------------------------------------------------- +// shouldFallbackToNextURL +// --------------------------------------------------------------------------- + +func TestShouldFallbackToNextURL_连接错误(t *testing.T) { + err := &net.OpError{Op: "dial", Net: "tcp", Err: fmt.Errorf("refused")} + if !shouldFallbackToNextURL(err, 0) { + t.Error("连接错误应触发 URL 降级") + } +} + +func TestShouldFallbackToNextURL_状态码(t *testing.T) { + tests := []struct { + name string + statusCode int + want bool + }{ + {"429 Too Many Requests", http.StatusTooManyRequests, true}, + {"408 Request Timeout", http.StatusRequestTimeout, true}, + {"404 Not Found", http.StatusNotFound, true}, + {"500 Internal Server Error", http.StatusInternalServerError, true}, + {"502 Bad Gateway", http.StatusBadGateway, true}, + {"503 Service Unavailable", http.StatusServiceUnavailable, true}, + {"200 OK", http.StatusOK, false}, + {"201 Created", http.StatusCreated, false}, + {"400 Bad Request", http.StatusBadRequest, false}, + {"401 Unauthorized", http.StatusUnauthorized, false}, + {"403 Forbidden", http.StatusForbidden, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := shouldFallbackToNextURL(nil, tt.statusCode) + if got != tt.want { + t.Errorf("shouldFallbackToNextURL(nil, %d) = 
%v, want %v", tt.statusCode, got, tt.want) + } + }) + } +} + +func TestShouldFallbackToNextURL_无错误且200(t *testing.T) { + if shouldFallbackToNextURL(nil, http.StatusOK) { + t.Error("无错误且 200 不应触发 URL 降级") + } +} + +// --------------------------------------------------------------------------- +// Client.ExchangeCode (使用 httptest) +// --------------------------------------------------------------------------- + +func TestClient_ExchangeCode_成功(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 验证请求方法 + if r.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s", r.Method) + } + // 验证 Content-Type + if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" { + t.Errorf("Content-Type 不匹配: got %s", ct) + } + // 验证请求体参数 + if err := r.ParseForm(); err != nil { + t.Fatalf("解析表单失败: %v", err) + } + if r.FormValue("client_id") != ClientID { + t.Errorf("client_id 不匹配: got %s", r.FormValue("client_id")) + } + if r.FormValue("client_secret") != "test-secret" { + t.Errorf("client_secret 不匹配: got %s", r.FormValue("client_secret")) + } + if r.FormValue("code") != "auth-code" { + t.Errorf("code 不匹配: got %s", r.FormValue("code")) + } + if r.FormValue("code_verifier") != "verifier123" { + t.Errorf("code_verifier 不匹配: got %s", r.FormValue("code_verifier")) + } + if r.FormValue("grant_type") != "authorization_code" { + t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type")) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "access-tok", + ExpiresIn: 3600, + TokenType: "Bearer", + RefreshToken: "refresh-tok", + }) + })) + defer server.Close() + + // 临时替换 TokenURL(该函数直接使用常量,需要我们通过构建自定义 client 来绕过) + // 由于 ExchangeCode 硬编码了 TokenURL,我们需要直接测试 HTTP client 的行为 + // 这里通过构造一个直接调用 
mock server 的测试 + client := &Client{httpClient: server.Client()} + + // 由于 ExchangeCode 使用硬编码的 TokenURL,我们无法直接注入 mock server URL + // 需要使用 httptest 的 Transport 重定向 + originalTokenURL := TokenURL + // 我们改为直接构造请求来测试逻辑 + _ = originalTokenURL + _ = client + + // 改用直接构造请求测试 mock server 响应 + ctx := context.Background() + params := url.Values{} + params.Set("client_id", ClientID) + params.Set("client_secret", "test-secret") + params.Set("code", "auth-code") + params.Set("redirect_uri", RedirectURI) + params.Set("grant_type", "authorization_code") + params.Set("code_verifier", "verifier123") + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, strings.NewReader(params.Encode())) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := server.Client().Do(req) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("状态码不匹配: got %d", resp.StatusCode) + } + + var tokenResp TokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + t.Fatalf("解码失败: %v", err) + } + if tokenResp.AccessToken != "access-tok" { + t.Errorf("AccessToken 不匹配: got %s", tokenResp.AccessToken) + } + if tokenResp.RefreshToken != "refresh-tok" { + t.Errorf("RefreshToken 不匹配: got %s", tokenResp.RefreshToken) + } +} + +func TestClient_ExchangeCode_无ClientSecret(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "" + t.Cleanup(func() { defaultClientSecret = old }) + + client := mustNewClient(t, "") + _, err := client.ExchangeCode(context.Background(), "code", "verifier") + if err == nil { + t.Fatal("缺少 client_secret 时应返回错误") + } + if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) { + t.Errorf("错误信息应包含环境变量名: got %s", err.Error()) + } +} + +func TestClient_ExchangeCode_服务器返回错误(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "test-secret" 
+ t.Cleanup(func() { defaultClientSecret = old }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":"invalid_grant"}`)) + })) + defer server.Close() + + // 直接测试 mock server 的错误响应 + resp, err := server.Client().Get(server.URL) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("状态码不匹配: got %d, want 400", resp.StatusCode) + } +} + +// --------------------------------------------------------------------------- +// Client.RefreshToken (使用 httptest) +// --------------------------------------------------------------------------- + +func TestClient_RefreshToken_MockServer(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s", r.Method) + } + if err := r.ParseForm(); err != nil { + t.Fatalf("解析表单失败: %v", err) + } + if r.FormValue("grant_type") != "refresh_token" { + t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type")) + } + if r.FormValue("refresh_token") != "old-refresh-tok" { + t.Errorf("refresh_token 不匹配: got %s", r.FormValue("refresh_token")) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "new-access-tok", + ExpiresIn: 3600, + TokenType: "Bearer", + }) + })) + defer server.Close() + + ctx := context.Background() + params := url.Values{} + params.Set("client_id", ClientID) + params.Set("client_secret", "test-secret") + params.Set("refresh_token", "old-refresh-tok") + params.Set("grant_type", "refresh_token") + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, 
strings.NewReader(params.Encode())) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := server.Client().Do(req) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("状态码不匹配: got %d", resp.StatusCode) + } + + var tokenResp TokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + t.Fatalf("解码失败: %v", err) + } + if tokenResp.AccessToken != "new-access-tok" { + t.Errorf("AccessToken 不匹配: got %s", tokenResp.AccessToken) + } +} + +func TestClient_RefreshToken_无ClientSecret(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "" + t.Cleanup(func() { defaultClientSecret = old }) + + client := mustNewClient(t, "") + _, err := client.RefreshToken(context.Background(), "refresh-tok") + if err == nil { + t.Fatal("缺少 client_secret 时应返回错误") + } +} + +// --------------------------------------------------------------------------- +// Client.GetUserInfo (使用 httptest) +// --------------------------------------------------------------------------- + +func TestClient_GetUserInfo_成功(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("请求方法不匹配: got %s", r.Method) + } + auth := r.Header.Get("Authorization") + if auth != "Bearer test-access-token" { + t.Errorf("Authorization 不匹配: got %s", auth) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(UserInfo{ + Email: "user@example.com", + Name: "Test User", + GivenName: "Test", + FamilyName: "User", + Picture: "https://example.com/photo.jpg", + }) + })) + defer server.Close() + + // 直接通过 mock server 测试 GetUserInfo 的行为逻辑 + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil) + if err != nil { + 
t.Fatalf("创建请求失败: %v", err) + } + req.Header.Set("Authorization", "Bearer test-access-token") + + resp, err := server.Client().Do(req) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("状态码不匹配: got %d", resp.StatusCode) + } + + var userInfo UserInfo + if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { + t.Fatalf("解码失败: %v", err) + } + if userInfo.Email != "user@example.com" { + t.Errorf("Email 不匹配: got %s", userInfo.Email) + } + if userInfo.Name != "Test User" { + t.Errorf("Name 不匹配: got %s", userInfo.Name) + } +} + +func TestClient_GetUserInfo_服务器返回错误(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"invalid_token"}`)) + })) + defer server.Close() + + resp, err := server.Client().Get(server.URL) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusUnauthorized { + t.Errorf("状态码不匹配: got %d, want 401", resp.StatusCode) + } +} + +// --------------------------------------------------------------------------- +// TokenResponse / UserInfo JSON 序列化 +// --------------------------------------------------------------------------- + +func TestTokenResponse_JSON序列化(t *testing.T) { + jsonData := `{"access_token":"at","expires_in":3600,"token_type":"Bearer","scope":"openid","refresh_token":"rt"}` + var resp TokenResponse + if err := json.Unmarshal([]byte(jsonData), &resp); err != nil { + t.Fatalf("反序列化失败: %v", err) + } + if resp.AccessToken != "at" { + t.Errorf("AccessToken 不匹配: got %s", resp.AccessToken) + } + if resp.ExpiresIn != 3600 { + t.Errorf("ExpiresIn 不匹配: got %d", resp.ExpiresIn) + } + if resp.RefreshToken != "rt" { + t.Errorf("RefreshToken 不匹配: got %s", resp.RefreshToken) + } +} + +func TestUserInfo_JSON序列化(t *testing.T) { + jsonData := 
`{"email":"a@b.com","name":"Alice"}` + var info UserInfo + if err := json.Unmarshal([]byte(jsonData), &info); err != nil { + t.Fatalf("反序列化失败: %v", err) + } + if info.Email != "a@b.com" { + t.Errorf("Email 不匹配: got %s", info.Email) + } + if info.Name != "Alice" { + t.Errorf("Name 不匹配: got %s", info.Name) + } +} + +// --------------------------------------------------------------------------- +// LoadCodeAssistResponse JSON 序列化 +// --------------------------------------------------------------------------- + +func TestLoadCodeAssistResponse_完整JSON(t *testing.T) { + jsonData := `{ + "cloudaicompanionProject": "proj-123", + "currentTier": "free-tier", + "paidTier": {"id": "g1-pro-tier", "name": "Pro"}, + "ineligibleTiers": [{"tier": {"id": "g1-ultra-tier"}, "reasonCode": "INELIGIBLE_ACCOUNT"}] + }` + var resp LoadCodeAssistResponse + if err := json.Unmarshal([]byte(jsonData), &resp); err != nil { + t.Fatalf("反序列化失败: %v", err) + } + if resp.CloudAICompanionProject != "proj-123" { + t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject) + } + if resp.GetTier() != "g1-pro-tier" { + t.Errorf("GetTier 不匹配: got %s", resp.GetTier()) + } + if len(resp.IneligibleTiers) != 1 { + t.Fatalf("IneligibleTiers 数量不匹配: got %d", len(resp.IneligibleTiers)) + } + if resp.IneligibleTiers[0].ReasonCode != "INELIGIBLE_ACCOUNT" { + t.Errorf("ReasonCode 不匹配: got %s", resp.IneligibleTiers[0].ReasonCode) + } +} + +// =========================================================================== +// 以下为新增测试:真正调用 Client 方法,通过 RoundTripper 拦截 HTTP 请求 +// =========================================================================== + +// redirectRoundTripper 将请求中特定前缀的 URL 重定向到 httptest server +type redirectRoundTripper struct { + // 原始 URL 前缀 -> 替换目标 URL 的映射 + redirects map[string]string + transport http.RoundTripper +} + +func (rt *redirectRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + originalURL := req.URL.String() + for prefix, target := range 
rt.redirects { + if strings.HasPrefix(originalURL, prefix) { + newURL := target + strings.TrimPrefix(originalURL, prefix) + parsed, err := url.Parse(newURL) + if err != nil { + return nil, err + } + req.URL = parsed + break + } + } + if rt.transport == nil { + return http.DefaultTransport.RoundTrip(req) + } + return rt.transport.RoundTrip(req) +} + +// newTestClientWithRedirect 创建一个 Client,将指定 URL 前缀的请求重定向到 mock server +func newTestClientWithRedirect(redirects map[string]string) *Client { + return &Client{ + httpClient: &http.Client{ + Timeout: 10 * time.Second, + Transport: &redirectRoundTripper{ + redirects: redirects, + }, + }, + } +} + +// --------------------------------------------------------------------------- +// Client.ExchangeCode - 真正调用方法的测试 +// --------------------------------------------------------------------------- + +func TestClient_ExchangeCode_Success_RealCall(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s, want POST", r.Method) + } + if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" { + t.Errorf("Content-Type 不匹配: got %s", ct) + } + if err := r.ParseForm(); err != nil { + t.Fatalf("解析表单失败: %v", err) + } + if r.FormValue("client_id") != ClientID { + t.Errorf("client_id 不匹配: got %s", r.FormValue("client_id")) + } + if r.FormValue("client_secret") != "test-secret" { + t.Errorf("client_secret 不匹配: got %s", r.FormValue("client_secret")) + } + if r.FormValue("code") != "test-auth-code" { + t.Errorf("code 不匹配: got %s", r.FormValue("code")) + } + if r.FormValue("code_verifier") != "test-verifier" { + t.Errorf("code_verifier 不匹配: got %s", r.FormValue("code_verifier")) + } + if r.FormValue("grant_type") != "authorization_code" { + t.Errorf("grant_type 不匹配: got %s", 
r.FormValue("grant_type")) + } + if r.FormValue("redirect_uri") != RedirectURI { + t.Errorf("redirect_uri 不匹配: got %s", r.FormValue("redirect_uri")) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "new-access-token", + ExpiresIn: 3600, + TokenType: "Bearer", + Scope: "openid email", + RefreshToken: "new-refresh-token", + }) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + tokenResp, err := client.ExchangeCode(context.Background(), "test-auth-code", "test-verifier") + if err != nil { + t.Fatalf("ExchangeCode 失败: %v", err) + } + if tokenResp.AccessToken != "new-access-token" { + t.Errorf("AccessToken 不匹配: got %s, want new-access-token", tokenResp.AccessToken) + } + if tokenResp.RefreshToken != "new-refresh-token" { + t.Errorf("RefreshToken 不匹配: got %s, want new-refresh-token", tokenResp.RefreshToken) + } + if tokenResp.ExpiresIn != 3600 { + t.Errorf("ExpiresIn 不匹配: got %d, want 3600", tokenResp.ExpiresIn) + } + if tokenResp.TokenType != "Bearer" { + t.Errorf("TokenType 不匹配: got %s, want Bearer", tokenResp.TokenType) + } + if tokenResp.Scope != "openid email" { + t.Errorf("Scope 不匹配: got %s, want openid email", tokenResp.Scope) + } +} + +func TestClient_ExchangeCode_ServerError_RealCall(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":"invalid_grant","error_description":"code expired"}`)) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + _, err := client.ExchangeCode(context.Background(), "expired-code", "verifier") + if err == nil { + t.Fatal("服务器返回 400 时应返回错误") + } + if 
!strings.Contains(err.Error(), "token 交换失败") { + t.Errorf("错误信息应包含 'token 交换失败': got %s", err.Error()) + } + if !strings.Contains(err.Error(), "400") { + t.Errorf("错误信息应包含状态码 400: got %s", err.Error()) + } +} + +func TestClient_ExchangeCode_InvalidJSON_RealCall(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{invalid json`)) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + _, err := client.ExchangeCode(context.Background(), "code", "verifier") + if err == nil { + t.Fatal("无效 JSON 响应应返回错误") + } + if !strings.Contains(err.Error(), "token 解析失败") { + t.Errorf("错误信息应包含 'token 解析失败': got %s", err.Error()) + } +} + +func TestClient_ExchangeCode_ContextCanceled_RealCall(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(5 * time.Second) // 模拟慢响应 + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // 立即取消 + + _, err := client.ExchangeCode(ctx, "code", "verifier") + if err == nil { + t.Fatal("context 取消时应返回错误") + } +} + +// --------------------------------------------------------------------------- +// Client.RefreshToken - 真正调用方法的测试 +// --------------------------------------------------------------------------- + +func TestClient_RefreshToken_Success_RealCall(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { 
defaultClientSecret = old }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s, want POST", r.Method) + } + if err := r.ParseForm(); err != nil { + t.Fatalf("解析表单失败: %v", err) + } + if r.FormValue("grant_type") != "refresh_token" { + t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type")) + } + if r.FormValue("refresh_token") != "my-refresh-token" { + t.Errorf("refresh_token 不匹配: got %s", r.FormValue("refresh_token")) + } + if r.FormValue("client_id") != ClientID { + t.Errorf("client_id 不匹配: got %s", r.FormValue("client_id")) + } + if r.FormValue("client_secret") != "test-secret" { + t.Errorf("client_secret 不匹配: got %s", r.FormValue("client_secret")) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "refreshed-access-token", + ExpiresIn: 3600, + TokenType: "Bearer", + }) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + tokenResp, err := client.RefreshToken(context.Background(), "my-refresh-token") + if err != nil { + t.Fatalf("RefreshToken 失败: %v", err) + } + if tokenResp.AccessToken != "refreshed-access-token" { + t.Errorf("AccessToken 不匹配: got %s, want refreshed-access-token", tokenResp.AccessToken) + } + if tokenResp.ExpiresIn != 3600 { + t.Errorf("ExpiresIn 不匹配: got %d, want 3600", tokenResp.ExpiresIn) + } +} + +func TestClient_RefreshToken_ServerError_RealCall(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"invalid_grant","error_description":"token revoked"}`)) + })) + defer server.Close() + + client := 
newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + _, err := client.RefreshToken(context.Background(), "revoked-token") + if err == nil { + t.Fatal("服务器返回 401 时应返回错误") + } + if !strings.Contains(err.Error(), "token 刷新失败") { + t.Errorf("错误信息应包含 'token 刷新失败': got %s", err.Error()) + } +} + +func TestClient_RefreshToken_InvalidJSON_RealCall(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`not-json`)) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + _, err := client.RefreshToken(context.Background(), "refresh-tok") + if err == nil { + t.Fatal("无效 JSON 响应应返回错误") + } + if !strings.Contains(err.Error(), "token 解析失败") { + t.Errorf("错误信息应包含 'token 解析失败': got %s", err.Error()) + } +} + +func TestClient_RefreshToken_ContextCanceled_RealCall(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "test-secret" + t.Cleanup(func() { defaultClientSecret = old }) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(5 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := client.RefreshToken(ctx, "refresh-tok") + if err == nil { + t.Fatal("context 取消时应返回错误") + } +} + +// --------------------------------------------------------------------------- +// Client.GetUserInfo - 真正调用方法的测试 +// --------------------------------------------------------------------------- + +func TestClient_GetUserInfo_Success_RealCall(t *testing.T) { + server := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("请求方法不匹配: got %s, want GET", r.Method) + } + auth := r.Header.Get("Authorization") + if auth != "Bearer user-access-token" { + t.Errorf("Authorization 不匹配: got %s", auth) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(UserInfo{ + Email: "test@example.com", + Name: "Test User", + GivenName: "Test", + FamilyName: "User", + Picture: "https://example.com/avatar.jpg", + }) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + UserInfoURL: server.URL, + }) + + userInfo, err := client.GetUserInfo(context.Background(), "user-access-token") + if err != nil { + t.Fatalf("GetUserInfo 失败: %v", err) + } + if userInfo.Email != "test@example.com" { + t.Errorf("Email 不匹配: got %s, want test@example.com", userInfo.Email) + } + if userInfo.Name != "Test User" { + t.Errorf("Name 不匹配: got %s, want Test User", userInfo.Name) + } + if userInfo.GivenName != "Test" { + t.Errorf("GivenName 不匹配: got %s, want Test", userInfo.GivenName) + } + if userInfo.FamilyName != "User" { + t.Errorf("FamilyName 不匹配: got %s, want User", userInfo.FamilyName) + } + if userInfo.Picture != "https://example.com/avatar.jpg" { + t.Errorf("Picture 不匹配: got %s", userInfo.Picture) + } +} + +func TestClient_GetUserInfo_Unauthorized_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"invalid_token"}`)) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + UserInfoURL: server.URL, + }) + + _, err := client.GetUserInfo(context.Background(), "bad-token") + if err == nil { + t.Fatal("服务器返回 401 时应返回错误") + } + if !strings.Contains(err.Error(), "获取用户信息失败") { + t.Errorf("错误信息应包含 '获取用户信息失败': got %s", err.Error()) + 
} + if !strings.Contains(err.Error(), "401") { + t.Errorf("错误信息应包含状态码 401: got %s", err.Error()) + } +} + +func TestClient_GetUserInfo_InvalidJSON_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{broken`)) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + UserInfoURL: server.URL, + }) + + _, err := client.GetUserInfo(context.Background(), "token") + if err == nil { + t.Fatal("无效 JSON 响应应返回错误") + } + if !strings.Contains(err.Error(), "用户信息解析失败") { + t.Errorf("错误信息应包含 '用户信息解析失败': got %s", err.Error()) + } +} + +func TestClient_GetUserInfo_ContextCanceled_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(5 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + UserInfoURL: server.URL, + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := client.GetUserInfo(ctx, "token") + if err == nil { + t.Fatal("context 取消时应返回错误") + } +} + +// --------------------------------------------------------------------------- +// Client.LoadCodeAssist - 真正调用方法的测试 +// --------------------------------------------------------------------------- + +// withMockBaseURLs 临时替换 BaseURLs,测试结束后恢复 +func withMockBaseURLs(t *testing.T, urls []string) { + t.Helper() + origBaseURLs := BaseURLs + origBaseURL := BaseURL + BaseURLs = urls + if len(urls) > 0 { + BaseURL = urls[0] + } + t.Cleanup(func() { + BaseURLs = origBaseURLs + BaseURL = origBaseURL + }) +} + +func TestClient_LoadCodeAssist_Success_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s, want POST", 
r.Method) + } + if !strings.HasSuffix(r.URL.Path, "/v1internal:loadCodeAssist") { + t.Errorf("URL 路径不匹配: got %s", r.URL.Path) + } + auth := r.Header.Get("Authorization") + if auth != "Bearer test-token" { + t.Errorf("Authorization 不匹配: got %s", auth) + } + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("Content-Type 不匹配: got %s", ct) + } + if ua := r.Header.Get("User-Agent"); ua != GetUserAgent() { + t.Errorf("User-Agent 不匹配: got %s", ua) + } + + // 验证请求体 + var reqBody LoadCodeAssistRequest + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + t.Fatalf("解析请求体失败: %v", err) + } + if reqBody.Metadata.IDEType != "ANTIGRAVITY" { + t.Errorf("IDEType 不匹配: got %s, want ANTIGRAVITY", reqBody.Metadata.IDEType) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "cloudaicompanionProject": "test-project-123", + "currentTier": {"id": "free-tier", "name": "Free"}, + "paidTier": {"id": "g1-pro-tier", "name": "Pro", "description": "Pro plan"} + }`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := mustNewClient(t, "") + resp, rawResp, err := client.LoadCodeAssist(context.Background(), "test-token") + if err != nil { + t.Fatalf("LoadCodeAssist 失败: %v", err) + } + if resp.CloudAICompanionProject != "test-project-123" { + t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject) + } + if resp.GetTier() != "g1-pro-tier" { + t.Errorf("GetTier 不匹配: got %s, want g1-pro-tier", resp.GetTier()) + } + if resp.CurrentTier == nil || resp.CurrentTier.ID != "free-tier" { + t.Errorf("CurrentTier 不匹配: got %+v", resp.CurrentTier) + } + if resp.PaidTier == nil || resp.PaidTier.ID != "g1-pro-tier" { + t.Errorf("PaidTier 不匹配: got %+v", resp.PaidTier) + } + // 验证原始 JSON map + if rawResp == nil { + t.Fatal("rawResp 不应为 nil") + } + if rawResp["cloudaicompanionProject"] != "test-project-123" { + t.Errorf("rawResp cloudaicompanionProject 
不匹配: got %v", rawResp["cloudaicompanionProject"]) + } +} + +func TestClient_LoadCodeAssist_HTTPError_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"error":"forbidden"}`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := mustNewClient(t, "") + _, _, err := client.LoadCodeAssist(context.Background(), "bad-token") + if err == nil { + t.Fatal("服务器返回 403 时应返回错误") + } + if !strings.Contains(err.Error(), "loadCodeAssist 失败") { + t.Errorf("错误信息应包含 'loadCodeAssist 失败': got %s", err.Error()) + } + if !strings.Contains(err.Error(), "403") { + t.Errorf("错误信息应包含状态码 403: got %s", err.Error()) + } +} + +func TestClient_LoadCodeAssist_InvalidJSON_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{not valid json!!!`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := mustNewClient(t, "") + _, _, err := client.LoadCodeAssist(context.Background(), "token") + if err == nil { + t.Fatal("无效 JSON 响应应返回错误") + } + if !strings.Contains(err.Error(), "响应解析失败") { + t.Errorf("错误信息应包含 '响应解析失败': got %s", err.Error()) + } +} + +func TestClient_LoadCodeAssist_URLFallback_RealCall(t *testing.T) { + // 第一个 server 返回 500,第二个 server 返回成功 + callCount := 0 + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error":"internal"}`)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + 
"cloudaicompanionProject": "fallback-project", + "currentTier": {"id": "free-tier", "name": "Free"} + }`)) + })) + defer server2.Close() + + withMockBaseURLs(t, []string{server1.URL, server2.URL}) + + client := mustNewClient(t, "") + resp, _, err := client.LoadCodeAssist(context.Background(), "token") + if err != nil { + t.Fatalf("LoadCodeAssist 应在 fallback 后成功: %v", err) + } + if resp.CloudAICompanionProject != "fallback-project" { + t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject) + } + if callCount != 2 { + t.Errorf("应该调用了 2 个 server,实际调用 %d 次", callCount) + } +} + +func TestClient_LoadCodeAssist_AllURLsFail_RealCall(t *testing.T) { + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = w.Write([]byte(`{"error":"unavailable"}`)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) + _, _ = w.Write([]byte(`{"error":"bad_gateway"}`)) + })) + defer server2.Close() + + withMockBaseURLs(t, []string{server1.URL, server2.URL}) + + client := mustNewClient(t, "") + _, _, err := client.LoadCodeAssist(context.Background(), "token") + if err == nil { + t.Fatal("所有 URL 都失败时应返回错误") + } +} + +func TestClient_LoadCodeAssist_ContextCanceled_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(5 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := mustNewClient(t, "") + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, _, err := client.LoadCodeAssist(ctx, "token") + if err == nil { + t.Fatal("context 取消时应返回错误") + } +} + +// --------------------------------------------------------------------------- +// Client.FetchAvailableModels - 真正调用方法的测试 +// 
--------------------------------------------------------------------------- + +func TestClient_FetchAvailableModels_Success_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s, want POST", r.Method) + } + if !strings.HasSuffix(r.URL.Path, "/v1internal:fetchAvailableModels") { + t.Errorf("URL 路径不匹配: got %s", r.URL.Path) + } + auth := r.Header.Get("Authorization") + if auth != "Bearer test-token" { + t.Errorf("Authorization 不匹配: got %s", auth) + } + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("Content-Type 不匹配: got %s", ct) + } + if ua := r.Header.Get("User-Agent"); ua != GetUserAgent() { + t.Errorf("User-Agent 不匹配: got %s", ua) + } + + // 验证请求体 + var reqBody FetchAvailableModelsRequest + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + t.Fatalf("解析请求体失败: %v", err) + } + if reqBody.Project != "project-abc" { + t.Errorf("Project 不匹配: got %s, want project-abc", reqBody.Project) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "models": { + "gemini-2.0-flash": { + "quotaInfo": { + "remainingFraction": 0.85, + "resetTime": "2025-01-01T00:00:00Z" + } + }, + "gemini-2.5-pro": { + "quotaInfo": { + "remainingFraction": 0.5 + } + } + } + }`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := mustNewClient(t, "") + resp, rawResp, err := client.FetchAvailableModels(context.Background(), "test-token", "project-abc") + if err != nil { + t.Fatalf("FetchAvailableModels 失败: %v", err) + } + if resp.Models == nil { + t.Fatal("Models 不应为 nil") + } + if len(resp.Models) != 2 { + t.Errorf("Models 数量不匹配: got %d, want 2", len(resp.Models)) + } + + flashModel, ok := resp.Models["gemini-2.0-flash"] + if !ok { + t.Fatal("缺少 gemini-2.0-flash 模型") + } + if flashModel.QuotaInfo == nil { + 
t.Fatal("gemini-2.0-flash QuotaInfo 不应为 nil") + } + if flashModel.QuotaInfo.RemainingFraction != 0.85 { + t.Errorf("RemainingFraction 不匹配: got %f, want 0.85", flashModel.QuotaInfo.RemainingFraction) + } + if flashModel.QuotaInfo.ResetTime != "2025-01-01T00:00:00Z" { + t.Errorf("ResetTime 不匹配: got %s", flashModel.QuotaInfo.ResetTime) + } + + proModel, ok := resp.Models["gemini-2.5-pro"] + if !ok { + t.Fatal("缺少 gemini-2.5-pro 模型") + } + if proModel.QuotaInfo == nil { + t.Fatal("gemini-2.5-pro QuotaInfo 不应为 nil") + } + if proModel.QuotaInfo.RemainingFraction != 0.5 { + t.Errorf("RemainingFraction 不匹配: got %f, want 0.5", proModel.QuotaInfo.RemainingFraction) + } + + // 验证原始 JSON map + if rawResp == nil { + t.Fatal("rawResp 不应为 nil") + } + if rawResp["models"] == nil { + t.Error("rawResp models 不应为 nil") + } +} + +func TestClient_FetchAvailableModels_HTTPError_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"error":"forbidden"}`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := mustNewClient(t, "") + _, _, err := client.FetchAvailableModels(context.Background(), "bad-token", "proj") + if err == nil { + t.Fatal("服务器返回 403 时应返回错误") + } + if !strings.Contains(err.Error(), "fetchAvailableModels 失败") { + t.Errorf("错误信息应包含 'fetchAvailableModels 失败': got %s", err.Error()) + } +} + +func TestClient_FetchAvailableModels_InvalidJSON_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`<<>>`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := mustNewClient(t, "") + _, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") + if err == nil { + t.Fatal("无效 JSON 响应应返回错误") + } + 
if !strings.Contains(err.Error(), "响应解析失败") { + t.Errorf("错误信息应包含 '响应解析失败': got %s", err.Error()) + } +} + +func TestClient_FetchAvailableModels_URLFallback_RealCall(t *testing.T) { + callCount := 0 + // 第一个 server 返回 429,第二个 server 返回成功 + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"error":"rate_limited"}`)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"models": {"model-a": {}}}`)) + })) + defer server2.Close() + + withMockBaseURLs(t, []string{server1.URL, server2.URL}) + + client := mustNewClient(t, "") + resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") + if err != nil { + t.Fatalf("FetchAvailableModels 应在 fallback 后成功: %v", err) + } + if _, ok := resp.Models["model-a"]; !ok { + t.Error("应返回 fallback server 的模型") + } + if callCount != 2 { + t.Errorf("应该调用了 2 个 server,实际调用 %d 次", callCount) + } +} + +func TestClient_FetchAvailableModels_AllURLsFail_RealCall(t *testing.T) { + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`not found`)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`internal error`)) + })) + defer server2.Close() + + withMockBaseURLs(t, []string{server1.URL, server2.URL}) + + client := mustNewClient(t, "") + _, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") + if err == nil { + t.Fatal("所有 URL 都失败时应返回错误") + } +} + +func TestClient_FetchAvailableModels_ContextCanceled_RealCall(t *testing.T) { + 
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(5 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := mustNewClient(t, "") + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, _, err := client.FetchAvailableModels(ctx, "token", "proj") + if err == nil { + t.Fatal("context 取消时应返回错误") + } +} + +func TestClient_FetchAvailableModels_EmptyModels_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"models": {}}`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := mustNewClient(t, "") + resp, rawResp, err := client.FetchAvailableModels(context.Background(), "token", "proj") + if err != nil { + t.Fatalf("FetchAvailableModels 失败: %v", err) + } + if resp.Models == nil { + t.Fatal("Models 不应为 nil") + } + if len(resp.Models) != 0 { + t.Errorf("Models 应为空: got %d", len(resp.Models)) + } + if rawResp == nil { + t.Fatal("rawResp 不应为 nil") + } +} + +// --------------------------------------------------------------------------- +// LoadCodeAssist 和 FetchAvailableModels 的 408 fallback 测试 +// --------------------------------------------------------------------------- + +func TestClient_LoadCodeAssist_408Fallback_RealCall(t *testing.T) { + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusRequestTimeout) + _, _ = w.Write([]byte(`timeout`)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"cloudaicompanionProject":"p2","currentTier":"free-tier"}`)) + })) + 
defer server2.Close() + + withMockBaseURLs(t, []string{server1.URL, server2.URL}) + + client := mustNewClient(t, "") + resp, _, err := client.LoadCodeAssist(context.Background(), "token") + if err != nil { + t.Fatalf("LoadCodeAssist 应在 408 fallback 后成功: %v", err) + } + if resp.CloudAICompanionProject != "p2" { + t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject) + } +} + +func TestClient_FetchAvailableModels_404Fallback_RealCall(t *testing.T) { + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`not found`)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"models":{"m1":{"quotaInfo":{"remainingFraction":1.0}}}}`)) + })) + defer server2.Close() + + withMockBaseURLs(t, []string{server1.URL, server2.URL}) + + client := mustNewClient(t, "") + resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") + if err != nil { + t.Fatalf("FetchAvailableModels 应在 404 fallback 后成功: %v", err) + } + if _, ok := resp.Models["m1"]; !ok { + t.Error("应返回 fallback server 的模型 m1") + } +} + func TestExtractProjectIDFromOnboardResponse(t *testing.T) { t.Parallel() diff --git a/backend/internal/pkg/antigravity/gemini_types.go b/backend/internal/pkg/antigravity/gemini_types.go index 32495827..0ff24a1f 100644 --- a/backend/internal/pkg/antigravity/gemini_types.go +++ b/backend/internal/pkg/antigravity/gemini_types.go @@ -70,7 +70,7 @@ type GeminiGenerationConfig struct { ImageConfig *GeminiImageConfig `json:"imageConfig,omitempty"` } -// GeminiImageConfig Gemini 图片生成配置(仅 gemini-3-pro-image 支持) +// GeminiImageConfig Gemini 图片生成配置(gemini-3-pro-image / gemini-3.1-flash-image 等图片模型支持) type GeminiImageConfig struct { AspectRatio string 
`json:"aspectRatio,omitempty"` // "1:1", "16:9", "9:16", "4:3", "3:4" ImageSize string `json:"imageSize,omitempty"` // "1K", "2K", "4K" diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index d1712c98..afffe9b1 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -6,10 +6,14 @@ import ( "encoding/base64" "encoding/hex" "fmt" + "net/http" "net/url" + "os" "strings" "sync" "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) const ( @@ -19,8 +23,10 @@ const ( UserInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo" // Antigravity OAuth 客户端凭证 - ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" - ClientSecret = "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" + ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" + + // AntigravityOAuthClientSecretEnv 是 Antigravity OAuth client_secret 的环境变量名。 + AntigravityOAuthClientSecretEnv = "ANTIGRAVITY_OAUTH_CLIENT_SECRET" // 固定的 redirect_uri(用户需手动复制 code) RedirectURI = "http://localhost:8085/callback" @@ -32,9 +38,6 @@ const ( "https://www.googleapis.com/auth/cclog " + "https://www.googleapis.com/auth/experimentsandconfigs" - // User-Agent(与 Antigravity-Manager 保持一致) - UserAgent = "antigravity/1.15.8 windows/amd64" - // Session 过期时间 SessionTTL = 30 * time.Minute @@ -46,6 +49,35 @@ const ( antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com" ) +// defaultUserAgentVersion 可通过环境变量 ANTIGRAVITY_USER_AGENT_VERSION 配置,默认 1.19.6 +var defaultUserAgentVersion = "1.19.6" + +// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置 +var defaultClientSecret = "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" + +func init() { + // 从环境变量读取版本号,未设置则使用默认值 + if version := os.Getenv("ANTIGRAVITY_USER_AGENT_VERSION"); version != "" { + defaultUserAgentVersion = version + } + // 从环境变量读取 client_secret,未设置则使用默认值 + if secret := 
os.Getenv(AntigravityOAuthClientSecretEnv); secret != "" { + defaultClientSecret = secret + } +} + +// GetUserAgent 返回当前配置的 User-Agent +func GetUserAgent() string { + return fmt.Sprintf("antigravity/%s windows/amd64", defaultUserAgentVersion) +} + +func getClientSecret() (string, error) { + if v := strings.TrimSpace(defaultClientSecret); v != "" { + return v, nil + } + return "", infraerrors.Newf(http.StatusBadRequest, "ANTIGRAVITY_OAUTH_CLIENT_SECRET_MISSING", "missing antigravity oauth client_secret; set %s", AntigravityOAuthClientSecretEnv) +} + // BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致) var BaseURLs = []string{ antigravityProdBaseURL, // prod (优先) diff --git a/backend/internal/pkg/antigravity/oauth_test.go b/backend/internal/pkg/antigravity/oauth_test.go new file mode 100644 index 00000000..743e2a33 --- /dev/null +++ b/backend/internal/pkg/antigravity/oauth_test.go @@ -0,0 +1,718 @@ +//go:build unit + +package antigravity + +import ( + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "net/url" + "os" + "strings" + "testing" + "time" +) + +// --------------------------------------------------------------------------- +// getClientSecret +// --------------------------------------------------------------------------- + +func TestGetClientSecret_环境变量设置(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "" + t.Cleanup(func() { defaultClientSecret = old }) + t.Setenv(AntigravityOAuthClientSecretEnv, "my-secret-value") + + // 需要重新触发 init 逻辑:手动从环境变量读取 + defaultClientSecret = os.Getenv(AntigravityOAuthClientSecretEnv) + + secret, err := getClientSecret() + if err != nil { + t.Fatalf("获取 client_secret 失败: %v", err) + } + if secret != "my-secret-value" { + t.Errorf("client_secret 不匹配: got %s, want my-secret-value", secret) + } +} + +func TestGetClientSecret_环境变量为空(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "" + t.Cleanup(func() { defaultClientSecret = old }) + + _, err := getClientSecret() + if err == 
nil { + t.Fatal("defaultClientSecret 为空时应返回错误") + } + if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) { + t.Errorf("错误信息应包含环境变量名: got %s", err.Error()) + } +} + +func TestGetClientSecret_环境变量未设置(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = "" + t.Cleanup(func() { defaultClientSecret = old }) + + _, err := getClientSecret() + if err == nil { + t.Fatal("defaultClientSecret 为空时应返回错误") + } +} + +func TestGetClientSecret_环境变量含空格(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = " " + t.Cleanup(func() { defaultClientSecret = old }) + + _, err := getClientSecret() + if err == nil { + t.Fatal("defaultClientSecret 仅含空格时应返回错误") + } +} + +func TestGetClientSecret_环境变量有前后空格(t *testing.T) { + old := defaultClientSecret + defaultClientSecret = " valid-secret " + t.Cleanup(func() { defaultClientSecret = old }) + + secret, err := getClientSecret() + if err != nil { + t.Fatalf("获取 client_secret 失败: %v", err) + } + if secret != "valid-secret" { + t.Errorf("应去除前后空格: got %q, want %q", secret, "valid-secret") + } +} + +// --------------------------------------------------------------------------- +// ForwardBaseURLs +// --------------------------------------------------------------------------- + +func TestForwardBaseURLs_Daily优先(t *testing.T) { + urls := ForwardBaseURLs() + if len(urls) == 0 { + t.Fatal("ForwardBaseURLs 返回空列表") + } + + // daily URL 应排在第一位 + if urls[0] != antigravityDailyBaseURL { + t.Errorf("第一个 URL 应为 daily: got %s, want %s", urls[0], antigravityDailyBaseURL) + } + + // 应包含所有 URL + if len(urls) != len(BaseURLs) { + t.Errorf("URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs)) + } + + // 验证 prod URL 也在列表中 + found := false + for _, u := range urls { + if u == antigravityProdBaseURL { + found = true + break + } + } + if !found { + t.Error("ForwardBaseURLs 中缺少 prod URL") + } +} + +func TestForwardBaseURLs_不修改原切片(t *testing.T) { + originalFirst := BaseURLs[0] + _ = ForwardBaseURLs() + // 确保原始 BaseURLs 未被修改 + 
if BaseURLs[0] != originalFirst { + t.Errorf("ForwardBaseURLs 不应修改原始 BaseURLs: got %s, want %s", BaseURLs[0], originalFirst) + } +} + +// --------------------------------------------------------------------------- +// URLAvailability +// --------------------------------------------------------------------------- + +func TestNewURLAvailability(t *testing.T) { + ua := NewURLAvailability(5 * time.Minute) + if ua == nil { + t.Fatal("NewURLAvailability 返回 nil") + } + if ua.ttl != 5*time.Minute { + t.Errorf("TTL 不匹配: got %v, want 5m", ua.ttl) + } + if ua.unavailable == nil { + t.Error("unavailable map 不应为 nil") + } +} + +func TestURLAvailability_MarkUnavailable(t *testing.T) { + ua := NewURLAvailability(5 * time.Minute) + testURL := "https://example.com" + + ua.MarkUnavailable(testURL) + + if ua.IsAvailable(testURL) { + t.Error("标记为不可用后 IsAvailable 应返回 false") + } +} + +func TestURLAvailability_MarkSuccess(t *testing.T) { + ua := NewURLAvailability(5 * time.Minute) + testURL := "https://example.com" + + // 先标记为不可用 + ua.MarkUnavailable(testURL) + if ua.IsAvailable(testURL) { + t.Error("标记为不可用后应不可用") + } + + // 标记成功后应恢复可用 + ua.MarkSuccess(testURL) + if !ua.IsAvailable(testURL) { + t.Error("MarkSuccess 后应恢复可用") + } + + // 验证 lastSuccess 被设置 + ua.mu.RLock() + if ua.lastSuccess != testURL { + t.Errorf("lastSuccess 不匹配: got %s, want %s", ua.lastSuccess, testURL) + } + ua.mu.RUnlock() +} + +func TestURLAvailability_IsAvailable_TTL过期(t *testing.T) { + // 使用极短的 TTL + ua := NewURLAvailability(1 * time.Millisecond) + testURL := "https://example.com" + + ua.MarkUnavailable(testURL) + // 等待 TTL 过期 + time.Sleep(5 * time.Millisecond) + + if !ua.IsAvailable(testURL) { + t.Error("TTL 过期后 URL 应恢复可用") + } +} + +func TestURLAvailability_IsAvailable_未标记的URL(t *testing.T) { + ua := NewURLAvailability(5 * time.Minute) + if !ua.IsAvailable("https://never-marked.com") { + t.Error("未标记的 URL 应默认可用") + } +} + +func TestURLAvailability_GetAvailableURLs(t *testing.T) { + ua := NewURLAvailability(10 * 
time.Minute) + + // 默认所有 URL 都可用 + urls := ua.GetAvailableURLs() + if len(urls) != len(BaseURLs) { + t.Errorf("可用 URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs)) + } +} + +func TestURLAvailability_GetAvailableURLs_标记一个不可用(t *testing.T) { + ua := NewURLAvailability(10 * time.Minute) + + if len(BaseURLs) < 2 { + t.Skip("BaseURLs 少于 2 个,跳过此测试") + } + + ua.MarkUnavailable(BaseURLs[0]) + urls := ua.GetAvailableURLs() + + // 标记的 URL 不应出现在可用列表中 + for _, u := range urls { + if u == BaseURLs[0] { + t.Errorf("被标记不可用的 URL 不应出现在可用列表中: %s", BaseURLs[0]) + } + } +} + +func TestURLAvailability_GetAvailableURLsWithBase(t *testing.T) { + ua := NewURLAvailability(10 * time.Minute) + customURLs := []string{"https://a.com", "https://b.com", "https://c.com"} + + urls := ua.GetAvailableURLsWithBase(customURLs) + if len(urls) != 3 { + t.Errorf("可用 URL 数量不匹配: got %d, want 3", len(urls)) + } +} + +func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess优先(t *testing.T) { + ua := NewURLAvailability(10 * time.Minute) + customURLs := []string{"https://a.com", "https://b.com", "https://c.com"} + + ua.MarkSuccess("https://c.com") + + urls := ua.GetAvailableURLsWithBase(customURLs) + if len(urls) != 3 { + t.Fatalf("可用 URL 数量不匹配: got %d, want 3", len(urls)) + } + // c.com 应排在第一位 + if urls[0] != "https://c.com" { + t.Errorf("lastSuccess 应排在第一位: got %s, want https://c.com", urls[0]) + } + // 其余按原始顺序 + if urls[1] != "https://a.com" { + t.Errorf("第二个应为 a.com: got %s", urls[1]) + } + if urls[2] != "https://b.com" { + t.Errorf("第三个应为 b.com: got %s", urls[2]) + } +} + +func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不可用(t *testing.T) { + ua := NewURLAvailability(10 * time.Minute) + customURLs := []string{"https://a.com", "https://b.com"} + + ua.MarkSuccess("https://b.com") + ua.MarkUnavailable("https://b.com") + + urls := ua.GetAvailableURLsWithBase(customURLs) + // b.com 被标记不可用,不应出现 + if len(urls) != 1 { + t.Fatalf("可用 URL 数量不匹配: got %d, want 1", len(urls)) + } + if urls[0] != 
"https://a.com" { + t.Errorf("仅 a.com 应可用: got %s", urls[0]) + } +} + +func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不在列表中(t *testing.T) { + ua := NewURLAvailability(10 * time.Minute) + customURLs := []string{"https://a.com", "https://b.com"} + + ua.MarkSuccess("https://not-in-list.com") + + urls := ua.GetAvailableURLsWithBase(customURLs) + // lastSuccess 不在自定义列表中,不应被添加 + if len(urls) != 2 { + t.Fatalf("可用 URL 数量不匹配: got %d, want 2", len(urls)) + } +} + +// --------------------------------------------------------------------------- +// SessionStore +// --------------------------------------------------------------------------- + +func TestNewSessionStore(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + if store == nil { + t.Fatal("NewSessionStore 返回 nil") + } + if store.sessions == nil { + t.Error("sessions map 不应为 nil") + } +} + +func TestSessionStore_SetAndGet(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + session := &OAuthSession{ + State: "test-state", + CodeVerifier: "test-verifier", + ProxyURL: "http://proxy.example.com", + CreatedAt: time.Now(), + } + + store.Set("session-1", session) + + got, ok := store.Get("session-1") + if !ok { + t.Fatal("Get 应返回 true") + } + if got.State != "test-state" { + t.Errorf("State 不匹配: got %s", got.State) + } + if got.CodeVerifier != "test-verifier" { + t.Errorf("CodeVerifier 不匹配: got %s", got.CodeVerifier) + } + if got.ProxyURL != "http://proxy.example.com" { + t.Errorf("ProxyURL 不匹配: got %s", got.ProxyURL) + } +} + +func TestSessionStore_Get_不存在(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + _, ok := store.Get("nonexistent") + if ok { + t.Error("不存在的 session 应返回 false") + } +} + +func TestSessionStore_Get_过期(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + session := &OAuthSession{ + State: "expired-state", + CreatedAt: time.Now().Add(-SessionTTL - time.Minute), // 已过期 + } + + store.Set("expired-session", session) + + _, ok 
:= store.Get("expired-session") + if ok { + t.Error("过期的 session 应返回 false") + } +} + +func TestSessionStore_Delete(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + session := &OAuthSession{ + State: "to-delete", + CreatedAt: time.Now(), + } + + store.Set("del-session", session) + store.Delete("del-session") + + _, ok := store.Get("del-session") + if ok { + t.Error("删除后 Get 应返回 false") + } +} + +func TestSessionStore_Delete_不存在(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + // 删除不存在的 session 不应 panic + store.Delete("nonexistent") +} + +func TestSessionStore_Stop(t *testing.T) { + store := NewSessionStore() + store.Stop() + + // 多次 Stop 不应 panic + store.Stop() +} + +func TestSessionStore_多个Session(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + for i := 0; i < 10; i++ { + session := &OAuthSession{ + State: "state-" + string(rune('0'+i)), + CreatedAt: time.Now(), + } + store.Set("session-"+string(rune('0'+i)), session) + } + + // 验证都能取到 + for i := 0; i < 10; i++ { + _, ok := store.Get("session-" + string(rune('0'+i))) + if !ok { + t.Errorf("session-%d 应存在", i) + } + } +} + +// --------------------------------------------------------------------------- +// GenerateRandomBytes +// --------------------------------------------------------------------------- + +func TestGenerateRandomBytes_长度正确(t *testing.T) { + sizes := []int{0, 1, 16, 32, 64, 128} + for _, size := range sizes { + b, err := GenerateRandomBytes(size) + if err != nil { + t.Fatalf("GenerateRandomBytes(%d) 失败: %v", size, err) + } + if len(b) != size { + t.Errorf("长度不匹配: got %d, want %d", len(b), size) + } + } +} + +func TestGenerateRandomBytes_不同调用产生不同结果(t *testing.T) { + b1, err := GenerateRandomBytes(32) + if err != nil { + t.Fatalf("第一次调用失败: %v", err) + } + b2, err := GenerateRandomBytes(32) + if err != nil { + t.Fatalf("第二次调用失败: %v", err) + } + // 两次生成的随机字节应该不同(概率上几乎不可能相同) + if string(b1) == string(b2) { + t.Error("两次生成的随机字节相同,概率极低,可能有问题") 
+ } +} + +// --------------------------------------------------------------------------- +// GenerateState +// --------------------------------------------------------------------------- + +func TestGenerateState_返回值格式(t *testing.T) { + state, err := GenerateState() + if err != nil { + t.Fatalf("GenerateState 失败: %v", err) + } + if state == "" { + t.Error("GenerateState 返回空字符串") + } + // base64url 编码不应包含 +, /, = + if strings.ContainsAny(state, "+/=") { + t.Errorf("GenerateState 返回值包含非 base64url 字符: %s", state) + } + // 32 字节的 base64url 编码长度应为 43(去掉了尾部 = 填充) + if len(state) != 43 { + t.Errorf("GenerateState 返回值长度不匹配: got %d, want 43", len(state)) + } +} + +func TestGenerateState_唯一性(t *testing.T) { + s1, _ := GenerateState() + s2, _ := GenerateState() + if s1 == s2 { + t.Error("两次 GenerateState 结果相同") + } +} + +// --------------------------------------------------------------------------- +// GenerateSessionID +// --------------------------------------------------------------------------- + +func TestGenerateSessionID_返回值格式(t *testing.T) { + id, err := GenerateSessionID() + if err != nil { + t.Fatalf("GenerateSessionID 失败: %v", err) + } + if id == "" { + t.Error("GenerateSessionID 返回空字符串") + } + // 16 字节的 hex 编码长度应为 32 + if len(id) != 32 { + t.Errorf("GenerateSessionID 返回值长度不匹配: got %d, want 32", len(id)) + } + // 验证是合法的 hex 字符串 + if _, err := hex.DecodeString(id); err != nil { + t.Errorf("GenerateSessionID 返回值不是合法的 hex 字符串: %s, err: %v", id, err) + } +} + +func TestGenerateSessionID_唯一性(t *testing.T) { + id1, _ := GenerateSessionID() + id2, _ := GenerateSessionID() + if id1 == id2 { + t.Error("两次 GenerateSessionID 结果相同") + } +} + +// --------------------------------------------------------------------------- +// GenerateCodeVerifier +// --------------------------------------------------------------------------- + +func TestGenerateCodeVerifier_返回值格式(t *testing.T) { + verifier, err := GenerateCodeVerifier() + if err != nil { + t.Fatalf("GenerateCodeVerifier 失败: %v", 
err) + } + if verifier == "" { + t.Error("GenerateCodeVerifier 返回空字符串") + } + // base64url 编码不应包含 +, /, = + if strings.ContainsAny(verifier, "+/=") { + t.Errorf("GenerateCodeVerifier 返回值包含非 base64url 字符: %s", verifier) + } + // 32 字节的 base64url 编码长度应为 43 + if len(verifier) != 43 { + t.Errorf("GenerateCodeVerifier 返回值长度不匹配: got %d, want 43", len(verifier)) + } +} + +func TestGenerateCodeVerifier_唯一性(t *testing.T) { + v1, _ := GenerateCodeVerifier() + v2, _ := GenerateCodeVerifier() + if v1 == v2 { + t.Error("两次 GenerateCodeVerifier 结果相同") + } +} + +// --------------------------------------------------------------------------- +// GenerateCodeChallenge +// --------------------------------------------------------------------------- + +func TestGenerateCodeChallenge_SHA256_Base64URL(t *testing.T) { + verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + + challenge := GenerateCodeChallenge(verifier) + + // 手动计算预期值 + hash := sha256.Sum256([]byte(verifier)) + expected := strings.TrimRight(base64.URLEncoding.EncodeToString(hash[:]), "=") + + if challenge != expected { + t.Errorf("CodeChallenge 不匹配: got %s, want %s", challenge, expected) + } +} + +func TestGenerateCodeChallenge_不含填充字符(t *testing.T) { + challenge := GenerateCodeChallenge("test-verifier") + if strings.Contains(challenge, "=") { + t.Errorf("CodeChallenge 不应包含 = 填充字符: %s", challenge) + } +} + +func TestGenerateCodeChallenge_不含非URL安全字符(t *testing.T) { + challenge := GenerateCodeChallenge("another-verifier") + if strings.ContainsAny(challenge, "+/") { + t.Errorf("CodeChallenge 不应包含 + 或 / 字符: %s", challenge) + } +} + +func TestGenerateCodeChallenge_相同输入相同输出(t *testing.T) { + c1 := GenerateCodeChallenge("same-verifier") + c2 := GenerateCodeChallenge("same-verifier") + if c1 != c2 { + t.Errorf("相同输入应产生相同输出: got %s and %s", c1, c2) + } +} + +func TestGenerateCodeChallenge_不同输入不同输出(t *testing.T) { + c1 := GenerateCodeChallenge("verifier-1") + c2 := GenerateCodeChallenge("verifier-2") + if c1 == c2 { + 
t.Error("不同输入应产生不同输出") + } +} + +// --------------------------------------------------------------------------- +// BuildAuthorizationURL +// --------------------------------------------------------------------------- + +func TestBuildAuthorizationURL_参数验证(t *testing.T) { + state := "test-state-123" + codeChallenge := "test-challenge-abc" + + authURL := BuildAuthorizationURL(state, codeChallenge) + + // 验证以 AuthorizeURL 开头 + if !strings.HasPrefix(authURL, AuthorizeURL+"?") { + t.Errorf("URL 应以 %s? 开头: got %s", AuthorizeURL, authURL) + } + + // 解析 URL 并验证参数 + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("解析 URL 失败: %v", err) + } + + params := parsed.Query() + + expectedParams := map[string]string{ + "client_id": ClientID, + "redirect_uri": RedirectURI, + "response_type": "code", + "scope": Scopes, + "state": state, + "code_challenge": codeChallenge, + "code_challenge_method": "S256", + "access_type": "offline", + "prompt": "consent", + "include_granted_scopes": "true", + } + + for key, want := range expectedParams { + got := params.Get(key) + if got != want { + t.Errorf("参数 %s 不匹配: got %q, want %q", key, got, want) + } + } +} + +func TestBuildAuthorizationURL_参数数量(t *testing.T) { + authURL := BuildAuthorizationURL("s", "c") + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("解析 URL 失败: %v", err) + } + + params := parsed.Query() + // 应包含 10 个参数 + expectedCount := 10 + if len(params) != expectedCount { + t.Errorf("参数数量不匹配: got %d, want %d", len(params), expectedCount) + } +} + +func TestBuildAuthorizationURL_特殊字符编码(t *testing.T) { + state := "state+with/special=chars" + codeChallenge := "challenge+value" + + authURL := BuildAuthorizationURL(state, codeChallenge) + + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("解析 URL 失败: %v", err) + } + + // 解析后应正确还原特殊字符 + if got := parsed.Query().Get("state"); got != state { + t.Errorf("state 参数编码/解码不匹配: got %q, want %q", got, state) + } +} + +// 
--------------------------------------------------------------------------- +// 常量值验证 +// --------------------------------------------------------------------------- + +func TestConstants_值正确(t *testing.T) { + if AuthorizeURL != "https://accounts.google.com/o/oauth2/v2/auth" { + t.Errorf("AuthorizeURL 不匹配: got %s", AuthorizeURL) + } + if TokenURL != "https://oauth2.googleapis.com/token" { + t.Errorf("TokenURL 不匹配: got %s", TokenURL) + } + if UserInfoURL != "https://www.googleapis.com/oauth2/v2/userinfo" { + t.Errorf("UserInfoURL 不匹配: got %s", UserInfoURL) + } + if ClientID != "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" { + t.Errorf("ClientID 不匹配: got %s", ClientID) + } + secret, err := getClientSecret() + if err != nil { + t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err) + } + if secret != "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" { + t.Errorf("默认 client_secret 不匹配: got %s", secret) + } + if RedirectURI != "http://localhost:8085/callback" { + t.Errorf("RedirectURI 不匹配: got %s", RedirectURI) + } + if GetUserAgent() != "antigravity/1.19.6 windows/amd64" { + t.Errorf("UserAgent 不匹配: got %s", GetUserAgent()) + } + if SessionTTL != 30*time.Minute { + t.Errorf("SessionTTL 不匹配: got %v", SessionTTL) + } + if URLAvailabilityTTL != 5*time.Minute { + t.Errorf("URLAvailabilityTTL 不匹配: got %v", URLAvailabilityTTL) + } +} + +func TestScopes_包含必要范围(t *testing.T) { + expectedScopes := []string{ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", + "https://www.googleapis.com/auth/cclog", + "https://www.googleapis.com/auth/experimentsandconfigs", + } + + for _, scope := range expectedScopes { + if !strings.Contains(Scopes, scope) { + t.Errorf("Scopes 缺少 %s", scope) + } + } +} diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index 3ba04b95..55cdd786 100644 --- 
a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -206,6 +206,7 @@ type modelInfo struct { var modelInfoMap = map[string]modelInfo{ "claude-opus-4-5": {DisplayName: "Claude Opus 4.5", CanonicalID: "claude-opus-4-5-20250929"}, "claude-opus-4-6": {DisplayName: "Claude Opus 4.6", CanonicalID: "claude-opus-4-6"}, + "claude-sonnet-4-6": {DisplayName: "Claude Sonnet 4.6", CanonicalID: "claude-sonnet-4-6"}, "claude-sonnet-4-5": {DisplayName: "Claude Sonnet 4.5", CanonicalID: "claude-sonnet-4-5-20250929"}, "claude-haiku-4-5": {DisplayName: "Claude Haiku 4.5", CanonicalID: "claude-haiku-4-5-20251001"}, } diff --git a/backend/internal/pkg/antigravity/response_transformer.go b/backend/internal/pkg/antigravity/response_transformer.go index 463033f1..f12effb6 100644 --- a/backend/internal/pkg/antigravity/response_transformer.go +++ b/backend/internal/pkg/antigravity/response_transformer.go @@ -1,10 +1,13 @@ package antigravity import ( + "crypto/rand" "encoding/json" "fmt" "log" "strings" + "sync/atomic" + "time" ) // TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式) @@ -341,12 +344,30 @@ func buildGroundingText(grounding *GeminiGroundingMetadata) string { return builder.String() } -// generateRandomID 生成随机 ID +// fallbackCounter 降级伪随机 ID 的全局计数器,混入 seed 避免高并发下 UnixNano 相同导致碰撞。 +var fallbackCounter uint64 + +// generateRandomID 生成密码学安全的随机 ID func generateRandomID() string { const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - result := make([]byte, 12) - for i := range result { - result[i] = chars[i%len(chars)] + id := make([]byte, 12) + randBytes := make([]byte, 12) + if _, err := rand.Read(randBytes); err != nil { + // 避免在请求路径里 panic:极端情况下熵源不可用时降级为伪随机。 + // 这里主要用于生成响应/工具调用的临时 ID,安全要求不高但需尽量避免碰撞。 + cnt := atomic.AddUint64(&fallbackCounter, 1) + seed := uint64(time.Now().UnixNano()) ^ cnt + seed ^= uint64(len(err.Error())) << 32 + for i := range id { + seed ^= seed << 13 + 
seed ^= seed >> 7 + seed ^= seed << 17 + id[i] = chars[int(seed)%len(chars)] + } + return string(id) } - return string(result) + for i, b := range randBytes { + id[i] = chars[int(b)%len(chars)] + } + return string(id) } diff --git a/backend/internal/pkg/antigravity/response_transformer_test.go b/backend/internal/pkg/antigravity/response_transformer_test.go new file mode 100644 index 00000000..da402b17 --- /dev/null +++ b/backend/internal/pkg/antigravity/response_transformer_test.go @@ -0,0 +1,109 @@ +//go:build unit + +package antigravity + +import ( + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- Task 7: 验证 generateRandomID 和降级碰撞防护 --- + +func TestGenerateRandomID_Uniqueness(t *testing.T) { + seen := make(map[string]struct{}, 100) + for i := 0; i < 100; i++ { + id := generateRandomID() + require.Len(t, id, 12, "ID 长度应为 12") + _, dup := seen[id] + require.False(t, dup, "第 %d 次调用生成了重复 ID: %s", i, id) + seen[id] = struct{}{} + } +} + +func TestFallbackCounter_Increments(t *testing.T) { + // 验证 fallbackCounter 的原子递增行为确保降级分支不会生成相同 seed + before := atomic.LoadUint64(&fallbackCounter) + cnt1 := atomic.AddUint64(&fallbackCounter, 1) + cnt2 := atomic.AddUint64(&fallbackCounter, 1) + require.Equal(t, before+1, cnt1, "第一次递增应为 before+1") + require.Equal(t, before+2, cnt2, "第二次递增应为 before+2") + require.NotEqual(t, cnt1, cnt2, "连续两次递增的计数器值应不同") +} + +func TestFallbackCounter_ConcurrentIncrements(t *testing.T) { + // 验证并发递增的原子性 — 每次递增都应产生唯一值 + const goroutines = 50 + results := make([]uint64, goroutines) + var wg sync.WaitGroup + wg.Add(goroutines) + + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + results[idx] = atomic.AddUint64(&fallbackCounter, 1) + }(i) + } + wg.Wait() + + // 所有结果应唯一 + seen := make(map[uint64]bool, goroutines) + for _, v := range results { + assert.False(t, seen[v], "并发递增产生了重复值: %d", v) + seen[v] = true + } +} + +func TestGenerateRandomID_Charset(t 
*testing.T) { + const validChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + validSet := make(map[byte]struct{}, len(validChars)) + for i := 0; i < len(validChars); i++ { + validSet[validChars[i]] = struct{}{} + } + + for i := 0; i < 50; i++ { + id := generateRandomID() + for j := 0; j < len(id); j++ { + _, ok := validSet[id[j]] + require.True(t, ok, "ID 包含非法字符: %c (ID=%s)", id[j], id) + } + } +} + +func TestGenerateRandomID_Length(t *testing.T) { + for i := 0; i < 100; i++ { + id := generateRandomID() + assert.Len(t, id, 12, "每次生成的 ID 长度应为 12") + } +} + +func TestGenerateRandomID_ConcurrentUniqueness(t *testing.T) { + // 验证并发调用不会产生重复 ID + const goroutines = 100 + results := make([]string, goroutines) + var wg sync.WaitGroup + wg.Add(goroutines) + + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + results[idx] = generateRandomID() + }(i) + } + wg.Wait() + + seen := make(map[string]bool, goroutines) + for _, id := range results { + assert.False(t, seen[id], "并发调用产生了重复 ID: %s", id) + seen[id] = true + } +} + +func BenchmarkGenerateRandomID(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = generateRandomID() + } +} diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go index eecee11e..22405382 100644 --- a/backend/internal/pkg/claude/constants.go +++ b/backend/internal/pkg/claude/constants.go @@ -10,8 +10,14 @@ const ( BetaInterleavedThinking = "interleaved-thinking-2025-05-14" BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14" BetaTokenCounting = "token-counting-2024-11-01" + BetaContext1M = "context-1m-2025-08-07" + BetaFastMode = "fast-mode-2026-02-01" ) +// DroppedBetas 是转发时需要从 anthropic-beta header 中移除的 beta token 列表。 +// 这些 token 是客户端特有的,不应透传给上游 API。 +var DroppedBetas = []string{BetaContext1M, BetaFastMode} + // DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + 
BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming @@ -77,6 +83,12 @@ var DefaultModels = []Model{ DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-06T00:00:00Z", }, + { + ID: "claude-sonnet-4-6", + Type: "model", + DisplayName: "Claude Sonnet 4.6", + CreatedAt: "2026-02-18T00:00:00Z", + }, { ID: "claude-sonnet-4-5-20250929", Type: "model", diff --git a/backend/internal/pkg/ctxkey/ctxkey.go b/backend/internal/pkg/ctxkey/ctxkey.go index 0c4d82f7..25782c55 100644 --- a/backend/internal/pkg/ctxkey/ctxkey.go +++ b/backend/internal/pkg/ctxkey/ctxkey.go @@ -8,9 +8,21 @@ const ( // ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置 ForcePlatform Key = "ctx_force_platform" + // RequestID 为服务端生成/透传的请求 ID。 + RequestID Key = "ctx_request_id" + // ClientRequestID 客户端请求的唯一标识,用于追踪请求全生命周期(用于 Ops 监控与排障)。 ClientRequestID Key = "ctx_client_request_id" + // Model 请求模型标识(用于统一请求链路日志字段)。 + Model Key = "ctx_model" + + // Platform 当前请求最终命中的平台(用于统一请求链路日志字段)。 + Platform Key = "ctx_platform" + + // AccountID 当前请求最终命中的账号 ID(用于统一请求链路日志字段)。 + AccountID Key = "ctx_account_id" + // RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。 RetryCount Key = "ctx_retry_count" @@ -32,4 +44,15 @@ const ( // SingleAccountRetry 标识当前请求处于单账号 503 退避重试模式。 // 在此模式下,Service 层的模型限流预检查将等待限流过期而非直接切换账号。 SingleAccountRetry Key = "ctx_single_account_retry" + + // PrefetchedStickyAccountID 标识上游(通常 handler)预取到的 sticky session 账号 ID。 + // Service 层可复用该值,避免同请求链路重复读取 Redis。 + PrefetchedStickyAccountID Key = "ctx_prefetched_sticky_account_id" + + // PrefetchedStickyGroupID 标识上游预取 sticky session 时所使用的分组 ID。 + // Service 层仅在分组匹配时复用 PrefetchedStickyAccountID,避免分组切换重试误用旧 sticky。 + PrefetchedStickyGroupID Key = "ctx_prefetched_sticky_group_id" + + // ClaudeCodeVersion stores the extracted Claude Code version from User-Agent (e.g. 
"2.1.22") + ClaudeCodeVersion Key = "ctx_claude_code_version" ) diff --git a/backend/internal/pkg/errors/errors_test.go b/backend/internal/pkg/errors/errors_test.go index 1a1c842e..25e62907 100644 --- a/backend/internal/pkg/errors/errors_test.go +++ b/backend/internal/pkg/errors/errors_test.go @@ -166,3 +166,18 @@ func TestToHTTP(t *testing.T) { }) } } + +func TestToHTTP_MetadataDeepCopy(t *testing.T) { + md := map[string]string{"k": "v"} + appErr := BadRequest("BAD_REQUEST", "invalid").WithMetadata(md) + + code, body := ToHTTP(appErr) + require.Equal(t, http.StatusBadRequest, code) + require.Equal(t, "v", body.Metadata["k"]) + + md["k"] = "changed" + require.Equal(t, "v", body.Metadata["k"]) + + appErr.Metadata["k"] = "changed-again" + require.Equal(t, "v", body.Metadata["k"]) +} diff --git a/backend/internal/pkg/errors/http.go b/backend/internal/pkg/errors/http.go index 7b5560e3..420c69a3 100644 --- a/backend/internal/pkg/errors/http.go +++ b/backend/internal/pkg/errors/http.go @@ -16,6 +16,16 @@ func ToHTTP(err error) (statusCode int, body Status) { return http.StatusOK, Status{Code: int32(http.StatusOK)} } - cloned := Clone(appErr) - return int(cloned.Code), cloned.Status + body = Status{ + Code: appErr.Code, + Reason: appErr.Reason, + Message: appErr.Message, + } + if appErr.Metadata != nil { + body.Metadata = make(map[string]string, len(appErr.Metadata)) + for k, v := range appErr.Metadata { + body.Metadata[k] = v + } + } + return int(appErr.Code), body } diff --git a/backend/internal/pkg/gemini/models.go b/backend/internal/pkg/gemini/models.go index 424e8ddb..c300b17d 100644 --- a/backend/internal/pkg/gemini/models.go +++ b/backend/internal/pkg/gemini/models.go @@ -21,6 +21,7 @@ func DefaultModels() []Model { {Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods}, {Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods}, {Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods}, + {Name: 
"models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods}, } } diff --git a/backend/internal/pkg/geminicli/constants.go b/backend/internal/pkg/geminicli/constants.go index d4d52116..97234ffd 100644 --- a/backend/internal/pkg/geminicli/constants.go +++ b/backend/internal/pkg/geminicli/constants.go @@ -41,6 +41,9 @@ const ( GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" GeminiCLIOAuthClientSecret = "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" + // GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret. + GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET" + SessionTTL = 30 * time.Minute // GeminiCLIUserAgent mimics Gemini CLI to maximize compatibility with internal endpoints. diff --git a/backend/internal/pkg/geminicli/models.go b/backend/internal/pkg/geminicli/models.go index 08e69886..1fc4d983 100644 --- a/backend/internal/pkg/geminicli/models.go +++ b/backend/internal/pkg/geminicli/models.go @@ -16,6 +16,7 @@ var DefaultModels = []Model{ {ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""}, {ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""}, {ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""}, + {ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""}, } // DefaultTestModel is the default model to preselect in test flows. 
diff --git a/backend/internal/pkg/geminicli/oauth.go b/backend/internal/pkg/geminicli/oauth.go index c71e8aad..b10b5750 100644 --- a/backend/internal/pkg/geminicli/oauth.go +++ b/backend/internal/pkg/geminicli/oauth.go @@ -6,10 +6,14 @@ import ( "encoding/base64" "encoding/hex" "fmt" + "net/http" "net/url" + "os" "strings" "sync" "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) type OAuthConfig struct { @@ -164,15 +168,24 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error } // Fall back to built-in Gemini CLI OAuth client when not configured. + // SECURITY: This repo does not embed the built-in client secret; it must be provided via env. if effective.ClientID == "" && effective.ClientSecret == "" { + secret := strings.TrimSpace(GeminiCLIOAuthClientSecret) + if secret == "" { + if v, ok := os.LookupEnv(GeminiCLIOAuthClientSecretEnv); ok { + secret = strings.TrimSpace(v) + } + } + if secret == "" { + return OAuthConfig{}, infraerrors.Newf(http.StatusBadRequest, "GEMINI_CLI_OAUTH_CLIENT_SECRET_MISSING", "built-in Gemini CLI OAuth client_secret is not configured; set %s or provide a custom OAuth client", GeminiCLIOAuthClientSecretEnv) + } effective.ClientID = GeminiCLIOAuthClientID - effective.ClientSecret = GeminiCLIOAuthClientSecret + effective.ClientSecret = secret } else if effective.ClientID == "" || effective.ClientSecret == "" { - return OAuthConfig{}, fmt.Errorf("OAuth client not configured: please set both client_id and client_secret (or leave both empty to use the built-in Gemini CLI client)") + return OAuthConfig{}, infraerrors.New(http.StatusBadRequest, "GEMINI_OAUTH_CLIENT_NOT_CONFIGURED", "OAuth client not configured: please set both client_id and client_secret (or leave both empty to use the built-in Gemini CLI client)") } - isBuiltinClient := effective.ClientID == GeminiCLIOAuthClientID && - effective.ClientSecret == GeminiCLIOAuthClientSecret + isBuiltinClient := effective.ClientID == 
GeminiCLIOAuthClientID if effective.Scopes == "" { // Use different default scopes based on OAuth type diff --git a/backend/internal/pkg/geminicli/oauth_test.go b/backend/internal/pkg/geminicli/oauth_test.go index 0770730a..2a430f9e 100644 --- a/backend/internal/pkg/geminicli/oauth_test.go +++ b/backend/internal/pkg/geminicli/oauth_test.go @@ -1,11 +1,441 @@ package geminicli import ( + "encoding/hex" "strings" + "sync" "testing" + "time" ) +// --------------------------------------------------------------------------- +// SessionStore 测试 +// --------------------------------------------------------------------------- + +func TestSessionStore_SetAndGet(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + session := &OAuthSession{ + State: "test-state", + OAuthType: "code_assist", + CreatedAt: time.Now(), + } + store.Set("sid-1", session) + + got, ok := store.Get("sid-1") + if !ok { + t.Fatal("期望 Get 返回 ok=true,实际返回 false") + } + if got.State != "test-state" { + t.Errorf("期望 State=%q,实际=%q", "test-state", got.State) + } +} + +func TestSessionStore_GetNotFound(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + _, ok := store.Get("不存在的ID") + if ok { + t.Error("期望不存在的 sessionID 返回 ok=false") + } +} + +func TestSessionStore_GetExpired(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + // 创建一个已过期的 session(CreatedAt 设置为 SessionTTL+1 分钟之前) + session := &OAuthSession{ + State: "expired-state", + OAuthType: "code_assist", + CreatedAt: time.Now().Add(-(SessionTTL + 1*time.Minute)), + } + store.Set("expired-sid", session) + + _, ok := store.Get("expired-sid") + if ok { + t.Error("期望过期的 session 返回 ok=false") + } +} + +func TestSessionStore_Delete(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + session := &OAuthSession{ + State: "to-delete", + OAuthType: "code_assist", + CreatedAt: time.Now(), + } + store.Set("del-sid", session) + + // 先确认存在 + if _, ok := store.Get("del-sid"); !ok { + t.Fatal("删除前 
session 应该存在") + } + + store.Delete("del-sid") + + if _, ok := store.Get("del-sid"); ok { + t.Error("删除后 session 不应该存在") + } +} + +func TestSessionStore_Stop_Idempotent(t *testing.T) { + store := NewSessionStore() + + // 多次调用 Stop 不应 panic + store.Stop() + store.Stop() + store.Stop() +} + +func TestSessionStore_ConcurrentAccess(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + const goroutines = 50 + var wg sync.WaitGroup + wg.Add(goroutines * 3) + + // 并发写入 + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + sid := "concurrent-" + string(rune('A'+idx%26)) + store.Set(sid, &OAuthSession{ + State: sid, + OAuthType: "code_assist", + CreatedAt: time.Now(), + }) + }(i) + } + + // 并发读取 + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + sid := "concurrent-" + string(rune('A'+idx%26)) + store.Get(sid) // 可能找到也可能没找到,关键是不 panic + }(i) + } + + // 并发删除 + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + sid := "concurrent-" + string(rune('A'+idx%26)) + store.Delete(sid) + }(i) + } + + wg.Wait() +} + +// --------------------------------------------------------------------------- +// GenerateRandomBytes 测试 +// --------------------------------------------------------------------------- + +func TestGenerateRandomBytes(t *testing.T) { + tests := []int{0, 1, 16, 32, 64} + for _, n := range tests { + b, err := GenerateRandomBytes(n) + if err != nil { + t.Errorf("GenerateRandomBytes(%d) 出错: %v", n, err) + continue + } + if len(b) != n { + t.Errorf("GenerateRandomBytes(%d) 返回长度=%d,期望=%d", n, len(b), n) + } + } +} + +func TestGenerateRandomBytes_Uniqueness(t *testing.T) { + // 两次调用应该返回不同的结果(极小概率相同,32字节足够) + a, _ := GenerateRandomBytes(32) + b, _ := GenerateRandomBytes(32) + if string(a) == string(b) { + t.Error("两次 GenerateRandomBytes(32) 返回了相同结果,随机性可能有问题") + } +} + +// --------------------------------------------------------------------------- +// GenerateState 测试 +// 
--------------------------------------------------------------------------- + +func TestGenerateState(t *testing.T) { + state, err := GenerateState() + if err != nil { + t.Fatalf("GenerateState() 出错: %v", err) + } + if state == "" { + t.Error("GenerateState() 返回空字符串") + } + // base64url 编码不应包含 padding '=' + if strings.Contains(state, "=") { + t.Errorf("GenerateState() 结果包含 '=' padding: %s", state) + } + // base64url 不应包含 '+' 或 '/' + if strings.ContainsAny(state, "+/") { + t.Errorf("GenerateState() 结果包含非 base64url 字符: %s", state) + } +} + +// --------------------------------------------------------------------------- +// GenerateSessionID 测试 +// --------------------------------------------------------------------------- + +func TestGenerateSessionID(t *testing.T) { + sid, err := GenerateSessionID() + if err != nil { + t.Fatalf("GenerateSessionID() 出错: %v", err) + } + // 16 字节 -> 32 个 hex 字符 + if len(sid) != 32 { + t.Errorf("GenerateSessionID() 长度=%d,期望=32", len(sid)) + } + // 必须是合法的 hex 字符串 + if _, err := hex.DecodeString(sid); err != nil { + t.Errorf("GenerateSessionID() 不是合法的 hex 字符串: %s, err=%v", sid, err) + } +} + +func TestGenerateSessionID_Uniqueness(t *testing.T) { + a, _ := GenerateSessionID() + b, _ := GenerateSessionID() + if a == b { + t.Error("两次 GenerateSessionID() 返回了相同结果") + } +} + +// --------------------------------------------------------------------------- +// GenerateCodeVerifier 测试 +// --------------------------------------------------------------------------- + +func TestGenerateCodeVerifier(t *testing.T) { + verifier, err := GenerateCodeVerifier() + if err != nil { + t.Fatalf("GenerateCodeVerifier() 出错: %v", err) + } + if verifier == "" { + t.Error("GenerateCodeVerifier() 返回空字符串") + } + // RFC 7636 要求 code_verifier 至少 43 个字符 + if len(verifier) < 43 { + t.Errorf("GenerateCodeVerifier() 长度=%d,RFC 7636 要求至少 43 字符", len(verifier)) + } + // base64url 编码不应包含 padding 和非 URL 安全字符 + if strings.Contains(verifier, "=") { + 
t.Errorf("GenerateCodeVerifier() 包含 '=' padding: %s", verifier) + } + if strings.ContainsAny(verifier, "+/") { + t.Errorf("GenerateCodeVerifier() 包含非 base64url 字符: %s", verifier) + } +} + +// --------------------------------------------------------------------------- +// GenerateCodeChallenge 测试 +// --------------------------------------------------------------------------- + +func TestGenerateCodeChallenge(t *testing.T) { + // 使用已知输入验证输出 + // RFC 7636 附录 B 示例: verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + // 预期 challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM" + verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + expected := "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM" + + challenge := GenerateCodeChallenge(verifier) + if challenge != expected { + t.Errorf("GenerateCodeChallenge(%q) = %q,期望 %q", verifier, challenge, expected) + } +} + +func TestGenerateCodeChallenge_NoPadding(t *testing.T) { + challenge := GenerateCodeChallenge("test-verifier-string") + if strings.Contains(challenge, "=") { + t.Errorf("GenerateCodeChallenge() 结果包含 '=' padding: %s", challenge) + } +} + +// --------------------------------------------------------------------------- +// base64URLEncode 测试 +// --------------------------------------------------------------------------- + +func TestBase64URLEncode(t *testing.T) { + tests := []struct { + name string + input []byte + }{ + {"空字节", []byte{}}, + {"单字节", []byte{0xff}}, + {"多字节", []byte{0x01, 0x02, 0x03, 0x04, 0x05}}, + {"全零", []byte{0x00, 0x00, 0x00}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := base64URLEncode(tt.input) + // 不应包含 '=' padding + if strings.Contains(result, "=") { + t.Errorf("base64URLEncode(%v) 包含 '=' padding: %s", tt.input, result) + } + // 不应包含标准 base64 的 '+' 或 '/' + if strings.ContainsAny(result, "+/") { + t.Errorf("base64URLEncode(%v) 包含非 URL 安全字符: %s", tt.input, result) + } + }) + } +} + +// 
--------------------------------------------------------------------------- +// hasRestrictedScope 测试 +// --------------------------------------------------------------------------- + +func TestHasRestrictedScope(t *testing.T) { + tests := []struct { + scope string + expected bool + }{ + // 受限 scope + {"https://www.googleapis.com/auth/generative-language", true}, + {"https://www.googleapis.com/auth/generative-language.retriever", true}, + {"https://www.googleapis.com/auth/generative-language.tuning", true}, + {"https://www.googleapis.com/auth/drive", true}, + {"https://www.googleapis.com/auth/drive.readonly", true}, + {"https://www.googleapis.com/auth/drive.file", true}, + // 非受限 scope + {"https://www.googleapis.com/auth/cloud-platform", false}, + {"https://www.googleapis.com/auth/userinfo.email", false}, + {"https://www.googleapis.com/auth/userinfo.profile", false}, + // 边界情况 + {"", false}, + {"random-scope", false}, + } + for _, tt := range tests { + t.Run(tt.scope, func(t *testing.T) { + got := hasRestrictedScope(tt.scope) + if got != tt.expected { + t.Errorf("hasRestrictedScope(%q) = %v,期望 %v", tt.scope, got, tt.expected) + } + }) + } +} + +// --------------------------------------------------------------------------- +// BuildAuthorizationURL 测试 +// --------------------------------------------------------------------------- + +func TestBuildAuthorizationURL(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret") + + authURL, err := BuildAuthorizationURL( + OAuthConfig{}, + "test-state", + "test-challenge", + "https://example.com/callback", + "", + "code_assist", + ) + if err != nil { + t.Fatalf("BuildAuthorizationURL() 出错: %v", err) + } + + // 检查返回的 URL 包含期望的参数 + checks := []string{ + "response_type=code", + "client_id=" + GeminiCLIOAuthClientID, + "redirect_uri=", + "state=test-state", + "code_challenge=test-challenge", + "code_challenge_method=S256", + "access_type=offline", + "prompt=consent", + "include_granted_scopes=true", + } + for _, 
check := range checks { + if !strings.Contains(authURL, check) { + t.Errorf("BuildAuthorizationURL() URL 缺少参数 %q\nURL: %s", check, authURL) + } + } + + // 不应包含 project_id(因为传的是空字符串) + if strings.Contains(authURL, "project_id=") { + t.Errorf("BuildAuthorizationURL() 空 projectID 时不应包含 project_id 参数") + } + + // URL 应该以正确的授权端点开头 + if !strings.HasPrefix(authURL, AuthorizeURL+"?") { + t.Errorf("BuildAuthorizationURL() URL 应以 %s? 开头,实际: %s", AuthorizeURL, authURL) + } +} + +func TestBuildAuthorizationURL_EmptyRedirectURI(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret") + + _, err := BuildAuthorizationURL( + OAuthConfig{}, + "test-state", + "test-challenge", + "", // 空 redirectURI + "", + "code_assist", + ) + if err == nil { + t.Error("BuildAuthorizationURL() 空 redirectURI 应该报错") + } + if !strings.Contains(err.Error(), "redirect_uri") { + t.Errorf("错误消息应包含 'redirect_uri',实际: %v", err) + } +} + +func TestBuildAuthorizationURL_WithProjectID(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret") + + authURL, err := BuildAuthorizationURL( + OAuthConfig{}, + "test-state", + "test-challenge", + "https://example.com/callback", + "my-project-123", + "code_assist", + ) + if err != nil { + t.Fatalf("BuildAuthorizationURL() 出错: %v", err) + } + if !strings.Contains(authURL, "project_id=my-project-123") { + t.Errorf("BuildAuthorizationURL() 带 projectID 时应包含 project_id 参数\nURL: %s", authURL) + } +} + +func TestBuildAuthorizationURL_UsesBuiltinSecretFallback(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "") + + authURL, err := BuildAuthorizationURL( + OAuthConfig{}, + "test-state", + "test-challenge", + "https://example.com/callback", + "", + "code_assist", + ) + if err != nil { + t.Fatalf("BuildAuthorizationURL() 不应报错: %v", err) + } + if !strings.Contains(authURL, "client_id="+GeminiCLIOAuthClientID) { + t.Errorf("应使用内置 Gemini CLI client_id,实际 URL: %s", authURL) + } +} + +// 
--------------------------------------------------------------------------- +// EffectiveOAuthConfig 测试 - 原有测试 +// --------------------------------------------------------------------------- + func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { + // 内置的 Gemini CLI client secret 不嵌入在此仓库中。 + // 测试通过环境变量设置一个假的 secret 来模拟运维配置。 + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + tests := []struct { name string input OAuthConfig @@ -15,7 +445,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { wantErr bool }{ { - name: "Google One with built-in client (empty config)", + name: "Google One 使用内置客户端(空配置)", input: OAuthConfig{}, oauthType: "google_one", wantClientID: GeminiCLIOAuthClientID, @@ -23,18 +453,18 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { wantErr: false, }, { - name: "Google One always uses built-in client (even if custom credentials passed)", + name: "Google One 使用自定义客户端(传入自定义凭据时使用自定义)", input: OAuthConfig{ ClientID: "custom-client-id", ClientSecret: "custom-client-secret", }, oauthType: "google_one", wantClientID: "custom-client-id", - wantScopes: DefaultCodeAssistScopes, // Uses code assist scopes even with custom client + wantScopes: DefaultCodeAssistScopes, wantErr: false, }, { - name: "Google One with built-in client and custom scopes (should filter restricted scopes)", + name: "Google One 内置客户端 + 自定义 scopes(应过滤受限 scopes)", input: OAuthConfig{ Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly", }, @@ -44,7 +474,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { wantErr: false, }, { - name: "Google One with built-in client and only restricted scopes (should fallback to default)", + name: "Google One 内置客户端 + 仅受限 scopes(应回退到默认)", input: OAuthConfig{ Scopes: "https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly", }, @@ -54,7 +484,7 @@ func 
TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { wantErr: false, }, { - name: "Code Assist with built-in client", + name: "Code Assist 使用内置客户端", input: OAuthConfig{}, oauthType: "code_assist", wantClientID: GeminiCLIOAuthClientID, @@ -84,7 +514,9 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { } func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) { - // Test that Google One with built-in client filters out restricted scopes + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + + // 测试 Google One + 内置客户端过滤受限 scopes cfg, err := EffectiveOAuthConfig(OAuthConfig{ Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.profile", }, "google_one") @@ -93,21 +525,242 @@ func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) { t.Fatalf("EffectiveOAuthConfig() error = %v", err) } - // Should only contain cloud-platform, userinfo.email, and userinfo.profile - // Should NOT contain generative-language or drive scopes + // 应仅包含 cloud-platform、userinfo.email 和 userinfo.profile + // 不应包含 generative-language 或 drive scopes if strings.Contains(cfg.Scopes, "generative-language") { - t.Errorf("Scopes should not contain generative-language when using built-in client, got: %v", cfg.Scopes) + t.Errorf("使用内置客户端时 Scopes 不应包含 generative-language,实际: %v", cfg.Scopes) } if strings.Contains(cfg.Scopes, "drive") { - t.Errorf("Scopes should not contain drive when using built-in client, got: %v", cfg.Scopes) + t.Errorf("使用内置客户端时 Scopes 不应包含 drive,实际: %v", cfg.Scopes) } if !strings.Contains(cfg.Scopes, "cloud-platform") { - t.Errorf("Scopes should contain cloud-platform, got: %v", cfg.Scopes) + t.Errorf("Scopes 应包含 cloud-platform,实际: %v", cfg.Scopes) } if !strings.Contains(cfg.Scopes, "userinfo.email") { - t.Errorf("Scopes should contain userinfo.email, got: %v", 
cfg.Scopes) + t.Errorf("Scopes 应包含 userinfo.email,实际: %v", cfg.Scopes) } if !strings.Contains(cfg.Scopes, "userinfo.profile") { - t.Errorf("Scopes should contain userinfo.profile, got: %v", cfg.Scopes) + t.Errorf("Scopes 应包含 userinfo.profile,实际: %v", cfg.Scopes) + } +} + +// --------------------------------------------------------------------------- +// EffectiveOAuthConfig 测试 - 新增分支覆盖 +// --------------------------------------------------------------------------- + +func TestEffectiveOAuthConfig_OnlyClientID_NoSecret(t *testing.T) { + // 只提供 clientID 不提供 secret 应报错 + _, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: "some-client-id", + }, "code_assist") + if err == nil { + t.Error("只提供 ClientID 不提供 ClientSecret 应该报错") + } + if !strings.Contains(err.Error(), "client_id") || !strings.Contains(err.Error(), "client_secret") { + t.Errorf("错误消息应提及 client_id 和 client_secret,实际: %v", err) + } +} + +func TestEffectiveOAuthConfig_OnlyClientSecret_NoID(t *testing.T) { + // 只提供 secret 不提供 clientID 应报错 + _, err := EffectiveOAuthConfig(OAuthConfig{ + ClientSecret: "some-client-secret", + }, "code_assist") + if err == nil { + t.Error("只提供 ClientSecret 不提供 ClientID 应该报错") + } + if !strings.Contains(err.Error(), "client_id") || !strings.Contains(err.Error(), "client_secret") { + t.Errorf("错误消息应提及 client_id 和 client_secret,实际: %v", err) + } +} + +func TestEffectiveOAuthConfig_AIStudio_DefaultScopes_BuiltinClient(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + + // ai_studio 类型,使用内置客户端,scopes 为空 -> 应使用 DefaultCodeAssistScopes(因为内置客户端不能请求 generative-language scope) + cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "ai_studio") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + if cfg.Scopes != DefaultCodeAssistScopes { + t.Errorf("ai_studio + 内置客户端应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_AIStudio_DefaultScopes_CustomClient(t *testing.T) { + // ai_studio 
类型,使用自定义客户端,scopes 为空 -> 应使用 DefaultAIStudioScopes + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: "custom-id", + ClientSecret: "custom-secret", + }, "ai_studio") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + if cfg.Scopes != DefaultAIStudioScopes { + t.Errorf("ai_studio + 自定义客户端应使用 DefaultAIStudioScopes,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_AIStudio_ScopeNormalization(t *testing.T) { + // ai_studio 类型,旧的 generative-language scope 应被归一化为 generative-language.retriever + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: "custom-id", + ClientSecret: "custom-secret", + Scopes: "https://www.googleapis.com/auth/generative-language https://www.googleapis.com/auth/cloud-platform", + }, "ai_studio") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + if strings.Contains(cfg.Scopes, "auth/generative-language ") || strings.HasSuffix(cfg.Scopes, "auth/generative-language") { + // 确保不包含未归一化的旧 scope(仅 generative-language 而非 generative-language.retriever) + parts := strings.Fields(cfg.Scopes) + for _, p := range parts { + if p == "https://www.googleapis.com/auth/generative-language" { + t.Errorf("ai_studio 应将 generative-language 归一化为 generative-language.retriever,实际 scopes: %q", cfg.Scopes) + } + } + } + if !strings.Contains(cfg.Scopes, "generative-language.retriever") { + t.Errorf("ai_studio 归一化后应包含 generative-language.retriever,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_CommaSeparatedScopes(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + + // 逗号分隔的 scopes 应被归一化为空格分隔 + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: "custom-id", + ClientSecret: "custom-secret", + Scopes: "https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/userinfo.email", + }, "code_assist") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + // 应该用空格分隔,而非逗号 + if strings.Contains(cfg.Scopes, ",") { + 
t.Errorf("逗号分隔的 scopes 应被归一化为空格分隔,实际: %q", cfg.Scopes) + } + if !strings.Contains(cfg.Scopes, "cloud-platform") { + t.Errorf("归一化后应包含 cloud-platform,实际: %q", cfg.Scopes) + } + if !strings.Contains(cfg.Scopes, "userinfo.email") { + t.Errorf("归一化后应包含 userinfo.email,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_MixedCommaAndSpaceScopes(t *testing.T) { + // 混合逗号和空格分隔的 scopes + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: "custom-id", + ClientSecret: "custom-secret", + Scopes: "https://www.googleapis.com/auth/cloud-platform, https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile", + }, "code_assist") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + parts := strings.Fields(cfg.Scopes) + if len(parts) != 3 { + t.Errorf("归一化后应有 3 个 scope,实际: %d,scopes: %q", len(parts), cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_WhitespaceTriming(t *testing.T) { + // 输入中的前后空白应被清理 + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: " custom-id ", + ClientSecret: " custom-secret ", + Scopes: " https://www.googleapis.com/auth/cloud-platform ", + }, "code_assist") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + if cfg.ClientID != "custom-id" { + t.Errorf("ClientID 应去除前后空白,实际: %q", cfg.ClientID) + } + if cfg.ClientSecret != "custom-secret" { + t.Errorf("ClientSecret 应去除前后空白,实际: %q", cfg.ClientSecret) + } + if cfg.Scopes != "https://www.googleapis.com/auth/cloud-platform" { + t.Errorf("Scopes 应去除前后空白,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_NoEnvSecret(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "") + + cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "code_assist") + if err != nil { + t.Fatalf("不设置环境变量时应回退到内置 secret,实际报错: %v", err) + } + if strings.TrimSpace(cfg.ClientSecret) == "" { + t.Error("ClientSecret 不应为空") + } + if cfg.ClientID != GeminiCLIOAuthClientID { + t.Errorf("ClientID 应回退为内置客户端 ID,实际: %q", cfg.ClientID) + } 
+} + +func TestEffectiveOAuthConfig_AIStudio_BuiltinClient_CustomScopes(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + + // ai_studio + 内置客户端 + 自定义 scopes -> 应过滤受限 scopes + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever", + }, "ai_studio") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + // 内置客户端应过滤 generative-language.retriever + if strings.Contains(cfg.Scopes, "generative-language") { + t.Errorf("ai_studio + 内置客户端应过滤受限 scopes,实际: %q", cfg.Scopes) + } + if !strings.Contains(cfg.Scopes, "cloud-platform") { + t.Errorf("应保留 cloud-platform scope,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_UnknownOAuthType_DefaultScopes(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + + // 未知的 oauthType 应回退到默认的 code_assist scopes + cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "unknown_type") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + if cfg.Scopes != DefaultCodeAssistScopes { + t.Errorf("未知 oauthType 应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_EmptyOAuthType_DefaultScopes(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + + // 空的 oauthType 应走 default 分支,使用 DefaultCodeAssistScopes + cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + if cfg.Scopes != DefaultCodeAssistScopes { + t.Errorf("空 oauthType 应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_CustomClient_NoScopeFiltering(t *testing.T) { + // 自定义客户端 + google_one + 包含受限 scopes -> 不应被过滤(因为不是内置客户端) + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: "custom-id", + ClientSecret: "custom-secret", + Scopes: "https://www.googleapis.com/auth/generative-language.retriever 
https://www.googleapis.com/auth/drive.readonly", + }, "google_one") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + // 自定义客户端不应过滤任何 scope + if !strings.Contains(cfg.Scopes, "generative-language.retriever") { + t.Errorf("自定义客户端不应过滤 generative-language.retriever,实际: %q", cfg.Scopes) + } + if !strings.Contains(cfg.Scopes, "drive.readonly") { + t.Errorf("自定义客户端不应过滤 drive.readonly,实际: %q", cfg.Scopes) } } diff --git a/backend/internal/pkg/httpclient/pool.go b/backend/internal/pkg/httpclient/pool.go index 76b7aa91..32e4bc5b 100644 --- a/backend/internal/pkg/httpclient/pool.go +++ b/backend/internal/pkg/httpclient/pool.go @@ -18,11 +18,11 @@ package httpclient import ( "fmt" "net/http" - "net/url" "strings" "sync" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" ) @@ -32,6 +32,7 @@ const ( defaultMaxIdleConns = 100 // 最大空闲连接数 defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数 defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间(建议小于上游 LB 超时) + validatedHostTTL = 30 * time.Second // DNS Rebinding 校验缓存 TTL ) // Options 定义共享 HTTP 客户端的构建参数 @@ -40,7 +41,6 @@ type Options struct { Timeout time.Duration // 请求总超时时间 ResponseHeaderTimeout time.Duration // 等待响应头超时时间 InsecureSkipVerify bool // 是否跳过 TLS 证书验证(已禁用,不允许设置为 true) - ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退 ValidateResolvedIP bool // 是否校验解析后的 IP(防止 DNS Rebinding) AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用) @@ -53,6 +53,9 @@ type Options struct { // sharedClients 存储按配置参数缓存的 http.Client 实例 var sharedClients sync.Map +// 允许测试替换校验函数,生产默认指向真实实现。 +var validateResolvedIP = urlvalidator.ValidateResolvedIP + // GetClient 返回共享的 HTTP 客户端实例 // 性能优化:相同配置复用同一客户端,避免重复创建 Transport // 安全说明:代理配置失败时直接返回错误,不会回退到直连,避免 IP 关联风险 @@ -84,7 +87,7 @@ func buildClient(opts Options) (*http.Client, error) { var rt http.RoundTripper = transport if opts.ValidateResolvedIP && 
!opts.AllowPrivateHosts { - rt = &validatedTransport{base: transport} + rt = newValidatedTransport(transport) } return &http.Client{ Transport: rt, @@ -116,15 +119,13 @@ func buildTransport(opts Options) (*http.Transport, error) { return nil, fmt.Errorf("insecure_skip_verify is not allowed; install a trusted certificate instead") } - proxyURL := strings.TrimSpace(opts.ProxyURL) - if proxyURL == "" { - return transport, nil - } - - parsed, err := url.Parse(proxyURL) + _, parsed, err := proxyurl.Parse(opts.ProxyURL) if err != nil { return nil, err } + if parsed == nil { + return transport, nil + } if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil { return nil, err @@ -134,12 +135,11 @@ func buildTransport(opts Options) (*http.Transport, error) { } func buildClientKey(opts Options) string { - return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%t|%d|%d|%d", + return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%d|%d|%d", strings.TrimSpace(opts.ProxyURL), opts.Timeout.String(), opts.ResponseHeaderTimeout.String(), opts.InsecureSkipVerify, - opts.ProxyStrict, opts.ValidateResolvedIP, opts.AllowPrivateHosts, opts.MaxIdleConns, @@ -149,17 +149,56 @@ func buildClientKey(opts Options) string { } type validatedTransport struct { - base http.RoundTripper + base http.RoundTripper + validatedHosts sync.Map // map[string]time.Time, value 为过期时间 + now func() time.Time +} + +func newValidatedTransport(base http.RoundTripper) *validatedTransport { + return &validatedTransport{ + base: base, + now: time.Now, + } +} + +func (t *validatedTransport) isValidatedHost(host string, now time.Time) bool { + if t == nil { + return false + } + raw, ok := t.validatedHosts.Load(host) + if !ok { + return false + } + expireAt, ok := raw.(time.Time) + if !ok { + t.validatedHosts.Delete(host) + return false + } + if now.Before(expireAt) { + return true + } + t.validatedHosts.Delete(host) + return false } func (t *validatedTransport) RoundTrip(req *http.Request) (*http.Response, error) { if req != nil && 
req.URL != nil { - host := strings.TrimSpace(req.URL.Hostname()) + host := strings.ToLower(strings.TrimSpace(req.URL.Hostname())) if host != "" { - if err := urlvalidator.ValidateResolvedIP(host); err != nil { - return nil, err + now := time.Now() + if t != nil && t.now != nil { + now = t.now() + } + if !t.isValidatedHost(host, now) { + if err := validateResolvedIP(host); err != nil { + return nil, err + } + t.validatedHosts.Store(host, now.Add(validatedHostTTL)) } } } + if t == nil || t.base == nil { + return nil, fmt.Errorf("validated transport base is nil") + } return t.base.RoundTrip(req) } diff --git a/backend/internal/pkg/httpclient/pool_test.go b/backend/internal/pkg/httpclient/pool_test.go new file mode 100644 index 00000000..f945758a --- /dev/null +++ b/backend/internal/pkg/httpclient/pool_test.go @@ -0,0 +1,115 @@ +package httpclient + +import ( + "errors" + "io" + "net/http" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func TestValidatedTransport_CacheHostValidation(t *testing.T) { + originalValidate := validateResolvedIP + defer func() { validateResolvedIP = originalValidate }() + + var validateCalls int32 + validateResolvedIP = func(host string) error { + atomic.AddInt32(&validateCalls, 1) + require.Equal(t, "api.openai.com", host) + return nil + } + + var baseCalls int32 + base := roundTripFunc(func(_ *http.Request) (*http.Response, error) { + atomic.AddInt32(&baseCalls, 1) + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{}`)), + Header: make(http.Header), + }, nil + }) + + now := time.Unix(1730000000, 0) + transport := newValidatedTransport(base) + transport.now = func() time.Time { return now } + + req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil) + 
require.NoError(t, err) + + _, err = transport.RoundTrip(req) + require.NoError(t, err) + _, err = transport.RoundTrip(req) + require.NoError(t, err) + + require.Equal(t, int32(1), atomic.LoadInt32(&validateCalls)) + require.Equal(t, int32(2), atomic.LoadInt32(&baseCalls)) +} + +func TestValidatedTransport_ExpiredCacheTriggersRevalidation(t *testing.T) { + originalValidate := validateResolvedIP + defer func() { validateResolvedIP = originalValidate }() + + var validateCalls int32 + validateResolvedIP = func(_ string) error { + atomic.AddInt32(&validateCalls, 1) + return nil + } + + base := roundTripFunc(func(_ *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{}`)), + Header: make(http.Header), + }, nil + }) + + now := time.Unix(1730001000, 0) + transport := newValidatedTransport(base) + transport.now = func() time.Time { return now } + + req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil) + require.NoError(t, err) + + _, err = transport.RoundTrip(req) + require.NoError(t, err) + + now = now.Add(validatedHostTTL + time.Second) + _, err = transport.RoundTrip(req) + require.NoError(t, err) + + require.Equal(t, int32(2), atomic.LoadInt32(&validateCalls)) +} + +func TestValidatedTransport_ValidationErrorStopsRoundTrip(t *testing.T) { + originalValidate := validateResolvedIP + defer func() { validateResolvedIP = originalValidate }() + + expectedErr := errors.New("dns rebinding rejected") + validateResolvedIP = func(_ string) error { + return expectedErr + } + + var baseCalls int32 + base := roundTripFunc(func(_ *http.Request) (*http.Response, error) { + atomic.AddInt32(&baseCalls, 1) + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(`{}`))}, nil + }) + + transport := newValidatedTransport(base) + req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil) + require.NoError(t, err) + + 
_, err = transport.RoundTrip(req) + require.ErrorIs(t, err, expectedErr) + require.Equal(t, int32(0), atomic.LoadInt32(&baseCalls)) +} diff --git a/backend/internal/pkg/httputil/body.go b/backend/internal/pkg/httputil/body.go new file mode 100644 index 00000000..69e99dc5 --- /dev/null +++ b/backend/internal/pkg/httputil/body.go @@ -0,0 +1,37 @@ +package httputil + +import ( + "bytes" + "io" + "net/http" +) + +const ( + requestBodyReadInitCap = 512 + requestBodyReadMaxInitCap = 1 << 20 +) + +// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based on content length. +func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) { + if req == nil || req.Body == nil { + return nil, nil + } + + capHint := requestBodyReadInitCap + if req.ContentLength > 0 { + switch { + case req.ContentLength < int64(requestBodyReadInitCap): + capHint = requestBodyReadInitCap + case req.ContentLength > int64(requestBodyReadMaxInitCap): + capHint = requestBodyReadMaxInitCap + default: + capHint = int(req.ContentLength) + } + } + + buf := bytes.NewBuffer(make([]byte, 0, capHint)) + if _, err := io.Copy(buf, req.Body); err != nil { + return nil, err + } + return buf.Bytes(), nil +} diff --git a/backend/internal/pkg/ip/ip.go b/backend/internal/pkg/ip/ip.go index 97109c0c..f6f77c86 100644 --- a/backend/internal/pkg/ip/ip.go +++ b/backend/internal/pkg/ip/ip.go @@ -44,6 +44,16 @@ func GetClientIP(c *gin.Context) string { return normalizeIP(c.ClientIP()) } +// GetTrustedClientIP 从 Gin 的可信代理解析链提取客户端 IP。 +// 该方法依赖 gin.Engine.SetTrustedProxies 配置,不会优先直接信任原始转发头值。 +// 适用于 ACL / 风控等安全敏感场景。 +func GetTrustedClientIP(c *gin.Context) string { + if c == nil { + return "" + } + return normalizeIP(c.ClientIP()) +} + // normalizeIP 规范化 IP 地址,去除端口号和空格。 func normalizeIP(ip string) string { ip = strings.TrimSpace(ip) @@ -54,29 +64,89 @@ func normalizeIP(ip string) string { return ip } -// isPrivateIP 检查 IP 是否为私有地址。 -func isPrivateIP(ipStr string) bool { - ip := net.ParseIP(ipStr) - 
if ip == nil { - return false - } +// privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析 +var privateNets []*net.IPNet - // 私有 IP 范围 - privateBlocks := []string{ +// CompiledIPRules 表示预编译的 IP 匹配规则。 +// PatternCount 记录原始规则数量,用于保留“规则存在但全无效”时的行为语义。 +type CompiledIPRules struct { + CIDRs []*net.IPNet + IPs []net.IP + PatternCount int +} + +func init() { + for _, cidr := range []string{ "10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "127.0.0.0/8", "::1/128", "fc00::/7", - } - - for _, block := range privateBlocks { - _, cidr, err := net.ParseCIDR(block) + } { + _, block, err := net.ParseCIDR(cidr) if err != nil { + panic("invalid CIDR: " + cidr) + } + privateNets = append(privateNets, block) + } +} + +// CompileIPRules 将 IP/CIDR 字符串规则预编译为可复用结构。 +// 非法规则会被忽略,但 PatternCount 会保留原始规则条数。 +func CompileIPRules(patterns []string) *CompiledIPRules { + compiled := &CompiledIPRules{ + CIDRs: make([]*net.IPNet, 0, len(patterns)), + IPs: make([]net.IP, 0, len(patterns)), + PatternCount: len(patterns), + } + for _, pattern := range patterns { + normalized := strings.TrimSpace(pattern) + if normalized == "" { continue } - if cidr.Contains(ip) { + if strings.Contains(normalized, "/") { + _, cidr, err := net.ParseCIDR(normalized) + if err != nil || cidr == nil { + continue + } + compiled.CIDRs = append(compiled.CIDRs, cidr) + continue + } + parsedIP := net.ParseIP(normalized) + if parsedIP == nil { + continue + } + compiled.IPs = append(compiled.IPs, parsedIP) + } + return compiled +} + +func matchesCompiledRules(parsedIP net.IP, rules *CompiledIPRules) bool { + if parsedIP == nil || rules == nil { + return false + } + for _, cidr := range rules.CIDRs { + if cidr.Contains(parsedIP) { + return true + } + } + for _, ruleIP := range rules.IPs { + if parsedIP.Equal(ruleIP) { + return true + } + } + return false +} + +// isPrivateIP 检查 IP 是否为私有地址。 +func isPrivateIP(ipStr string) bool { + ip := net.ParseIP(ipStr) + if ip == nil { + return false + } + for _, block := range privateNets { + 
if block.Contains(ip) { return true } } @@ -127,19 +197,32 @@ func MatchesAnyPattern(clientIP string, patterns []string) bool { // 2. 如果白名单不为空,IP 必须在白名单中 // 3. 如果白名单为空,允许访问(除非被黑名单拒绝) func CheckIPRestriction(clientIP string, whitelist, blacklist []string) (bool, string) { + return CheckIPRestrictionWithCompiledRules( + clientIP, + CompileIPRules(whitelist), + CompileIPRules(blacklist), + ) +} + +// CheckIPRestrictionWithCompiledRules 使用预编译规则检查 IP 是否允许访问。 +func CheckIPRestrictionWithCompiledRules(clientIP string, whitelist, blacklist *CompiledIPRules) (bool, string) { // 规范化 IP clientIP = normalizeIP(clientIP) if clientIP == "" { return false, "access denied" } + parsedIP := net.ParseIP(clientIP) + if parsedIP == nil { + return false, "access denied" + } // 1. 检查黑名单 - if len(blacklist) > 0 && MatchesAnyPattern(clientIP, blacklist) { + if blacklist != nil && blacklist.PatternCount > 0 && matchesCompiledRules(parsedIP, blacklist) { return false, "access denied" } // 2. 检查白名单(如果设置了白名单,IP 必须在其中) - if len(whitelist) > 0 && !MatchesAnyPattern(clientIP, whitelist) { + if whitelist != nil && whitelist.PatternCount > 0 && !matchesCompiledRules(parsedIP, whitelist) { return false, "access denied" } diff --git a/backend/internal/pkg/ip/ip_test.go b/backend/internal/pkg/ip/ip_test.go new file mode 100644 index 00000000..403b2d59 --- /dev/null +++ b/backend/internal/pkg/ip/ip_test.go @@ -0,0 +1,96 @@ +//go:build unit + +package ip + +import ( + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestIsPrivateIP(t *testing.T) { + tests := []struct { + name string + ip string + expected bool + }{ + // 私有 IPv4 + {"10.x 私有地址", "10.0.0.1", true}, + {"10.x 私有地址段末", "10.255.255.255", true}, + {"172.16.x 私有地址", "172.16.0.1", true}, + {"172.31.x 私有地址", "172.31.255.255", true}, + {"192.168.x 私有地址", "192.168.1.1", true}, + {"127.0.0.1 本地回环", "127.0.0.1", true}, + {"127.x 回环段", "127.255.255.255", true}, + + // 公网 IPv4 + {"8.8.8.8 
公网 DNS", "8.8.8.8", false}, + {"1.1.1.1 公网", "1.1.1.1", false}, + {"172.15.255.255 非私有", "172.15.255.255", false}, + {"172.32.0.0 非私有", "172.32.0.0", false}, + {"11.0.0.1 公网", "11.0.0.1", false}, + + // IPv6 + {"::1 IPv6 回环", "::1", true}, + {"fc00:: IPv6 私有", "fc00::1", true}, + {"fd00:: IPv6 私有", "fd00::1", true}, + {"2001:db8::1 IPv6 公网", "2001:db8::1", false}, + + // 无效输入 + {"空字符串", "", false}, + {"非法字符串", "not-an-ip", false}, + {"不完整 IP", "192.168", false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := isPrivateIP(tc.ip) + require.Equal(t, tc.expected, got, "isPrivateIP(%q)", tc.ip) + }) + } +} + +func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + require.NoError(t, r.SetTrustedProxies(nil)) + + r.GET("/t", func(c *gin.Context) { + c.String(200, GetTrustedClientIP(c)) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/t", nil) + req.RemoteAddr = "9.9.9.9:12345" + req.Header.Set("X-Forwarded-For", "1.2.3.4") + req.Header.Set("X-Real-IP", "1.2.3.4") + req.Header.Set("CF-Connecting-IP", "1.2.3.4") + r.ServeHTTP(w, req) + + require.Equal(t, 200, w.Code) + require.Equal(t, "9.9.9.9", w.Body.String()) +} + +func TestCheckIPRestrictionWithCompiledRules(t *testing.T) { + whitelist := CompileIPRules([]string{"10.0.0.0/8", "192.168.1.2"}) + blacklist := CompileIPRules([]string{"10.1.1.1"}) + + allowed, reason := CheckIPRestrictionWithCompiledRules("10.2.3.4", whitelist, blacklist) + require.True(t, allowed) + require.Equal(t, "", reason) + + allowed, reason = CheckIPRestrictionWithCompiledRules("10.1.1.1", whitelist, blacklist) + require.False(t, allowed) + require.Equal(t, "access denied", reason) +} + +func TestCheckIPRestrictionWithCompiledRules_InvalidWhitelistStillDenies(t *testing.T) { + // 与旧实现保持一致:白名单有配置但全无效时,最终应拒绝访问。 + invalidWhitelist := CompileIPRules([]string{"not-a-valid-pattern"}) + allowed, reason := 
CheckIPRestrictionWithCompiledRules("8.8.8.8", invalidWhitelist, nil) + require.False(t, allowed) + require.Equal(t, "access denied", reason) +} diff --git a/backend/internal/pkg/logger/config_adapter.go b/backend/internal/pkg/logger/config_adapter.go new file mode 100644 index 00000000..c34e448b --- /dev/null +++ b/backend/internal/pkg/logger/config_adapter.go @@ -0,0 +1,31 @@ +package logger + +import "github.com/Wei-Shaw/sub2api/internal/config" + +func OptionsFromConfig(cfg config.LogConfig) InitOptions { + return InitOptions{ + Level: cfg.Level, + Format: cfg.Format, + ServiceName: cfg.ServiceName, + Environment: cfg.Environment, + Caller: cfg.Caller, + StacktraceLevel: cfg.StacktraceLevel, + Output: OutputOptions{ + ToStdout: cfg.Output.ToStdout, + ToFile: cfg.Output.ToFile, + FilePath: cfg.Output.FilePath, + }, + Rotation: RotationOptions{ + MaxSizeMB: cfg.Rotation.MaxSizeMB, + MaxBackups: cfg.Rotation.MaxBackups, + MaxAgeDays: cfg.Rotation.MaxAgeDays, + Compress: cfg.Rotation.Compress, + LocalTime: cfg.Rotation.LocalTime, + }, + Sampling: SamplingOptions{ + Enabled: cfg.Sampling.Enabled, + Initial: cfg.Sampling.Initial, + Thereafter: cfg.Sampling.Thereafter, + }, + } +} diff --git a/backend/internal/pkg/logger/logger.go b/backend/internal/pkg/logger/logger.go new file mode 100644 index 00000000..3fca706e --- /dev/null +++ b/backend/internal/pkg/logger/logger.go @@ -0,0 +1,530 @@ +package logger + +import ( + "context" + "fmt" + "io" + "log" + "log/slog" + "os" + "path/filepath" + "strings" + "sync" + "sync/atomic" + "time" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "gopkg.in/natefinch/lumberjack.v2" +) + +type Level = zapcore.Level + +const ( + LevelDebug = zapcore.DebugLevel + LevelInfo = zapcore.InfoLevel + LevelWarn = zapcore.WarnLevel + LevelError = zapcore.ErrorLevel + LevelFatal = zapcore.FatalLevel +) + +type Sink interface { + WriteLogEvent(event *LogEvent) +} + +type LogEvent struct { + Time time.Time + Level string + Component string + 
Message string + LoggerName string + Fields map[string]any +} + +var ( + mu sync.RWMutex + global atomic.Pointer[zap.Logger] + sugar atomic.Pointer[zap.SugaredLogger] + atomicLevel zap.AtomicLevel + initOptions InitOptions + currentSink atomic.Value // sinkState + stdLogUndo func() + bootstrapOnce sync.Once +) + +type sinkState struct { + sink Sink +} + +func InitBootstrap() { + bootstrapOnce.Do(func() { + if err := Init(bootstrapOptions()); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "logger bootstrap init failed: %v\n", err) + } + }) +} + +func Init(options InitOptions) error { + mu.Lock() + defer mu.Unlock() + return initLocked(options) +} + +func initLocked(options InitOptions) error { + normalized := options.normalized() + zl, al, err := buildLogger(normalized) + if err != nil { + return err + } + + prev := global.Load() + global.Store(zl) + sugar.Store(zl.Sugar()) + atomicLevel = al + initOptions = normalized + + bridgeSlogLocked() + bridgeStdLogLocked() + + if prev != nil { + _ = prev.Sync() + } + return nil +} + +func Reconfigure(mutator func(*InitOptions) error) error { + mu.Lock() + defer mu.Unlock() + next := initOptions + if mutator != nil { + if err := mutator(&next); err != nil { + return err + } + } + return initLocked(next) +} + +func SetLevel(level string) error { + lv, ok := parseLevel(level) + if !ok { + return fmt.Errorf("invalid log level: %s", level) + } + + mu.Lock() + defer mu.Unlock() + atomicLevel.SetLevel(lv) + initOptions.Level = strings.ToLower(strings.TrimSpace(level)) + return nil +} + +func CurrentLevel() string { + mu.RLock() + defer mu.RUnlock() + if global.Load() == nil { + return "info" + } + return atomicLevel.Level().String() +} + +func SetSink(sink Sink) { + currentSink.Store(sinkState{sink: sink}) +} + +func loadSink() Sink { + v := currentSink.Load() + if v == nil { + return nil + } + state, ok := v.(sinkState) + if !ok { + return nil + } + return state.sink +} + +// WriteSinkEvent 直接写入日志 sink,不经过全局日志级别门控。 +// 
用于需要“可观测性入库”与“业务输出级别”解耦的场景(例如 ops 系统日志索引)。 +func WriteSinkEvent(level, component, message string, fields map[string]any) { + sink := loadSink() + if sink == nil { + return + } + + level = strings.ToLower(strings.TrimSpace(level)) + if level == "" { + level = "info" + } + component = strings.TrimSpace(component) + message = strings.TrimSpace(message) + if message == "" { + return + } + + eventFields := make(map[string]any, len(fields)+1) + for k, v := range fields { + eventFields[k] = v + } + if component != "" { + if _, ok := eventFields["component"]; !ok { + eventFields["component"] = component + } + } + + sink.WriteLogEvent(&LogEvent{ + Time: time.Now(), + Level: level, + Component: component, + Message: message, + LoggerName: component, + Fields: eventFields, + }) +} + +func L() *zap.Logger { + if l := global.Load(); l != nil { + return l + } + return zap.NewNop() +} + +func S() *zap.SugaredLogger { + if s := sugar.Load(); s != nil { + return s + } + return zap.NewNop().Sugar() +} + +func With(fields ...zap.Field) *zap.Logger { + return L().With(fields...) 
+} + +func Sync() { + l := global.Load() + if l != nil { + _ = l.Sync() + } +} + +func bridgeStdLogLocked() { + if stdLogUndo != nil { + stdLogUndo() + stdLogUndo = nil + } + + prevFlags := log.Flags() + prevPrefix := log.Prefix() + prevWriter := log.Writer() + + log.SetFlags(0) + log.SetPrefix("") + base := global.Load() + if base == nil { + base = zap.NewNop() + } + log.SetOutput(newStdLogBridge(base.Named("stdlog"))) + + stdLogUndo = func() { + log.SetOutput(prevWriter) + log.SetFlags(prevFlags) + log.SetPrefix(prevPrefix) + } +} + +func bridgeSlogLocked() { + base := global.Load() + if base == nil { + base = zap.NewNop() + } + slog.SetDefault(slog.New(newSlogZapHandler(base.Named("slog")))) +} + +func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) { + level, _ := parseLevel(options.Level) + atomic := zap.NewAtomicLevelAt(level) + + encoderCfg := zapcore.EncoderConfig{ + TimeKey: "time", + LevelKey: "level", + NameKey: "logger", + CallerKey: "caller", + MessageKey: "msg", + StacktraceKey: "stacktrace", + LineEnding: zapcore.DefaultLineEnding, + EncodeLevel: zapcore.CapitalLevelEncoder, + EncodeTime: zapcore.ISO8601TimeEncoder, + EncodeDuration: zapcore.MillisDurationEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + } + + var enc zapcore.Encoder + if options.Format == "console" { + enc = zapcore.NewConsoleEncoder(encoderCfg) + } else { + enc = zapcore.NewJSONEncoder(encoderCfg) + } + + sinkCore := newSinkCore() + cores := make([]zapcore.Core, 0, 3) + + if options.Output.ToStdout { + infoPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool { + return lvl >= atomic.Level() && lvl < zapcore.WarnLevel + }) + errPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool { + return lvl >= atomic.Level() && lvl >= zapcore.WarnLevel + }) + cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stdout), infoPriority)) + cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stderr), errPriority)) + } + + if 
options.Output.ToFile { + fileCore, filePath, fileErr := buildFileCore(enc, atomic, options) + if fileErr != nil { + _, _ = fmt.Fprintf(os.Stderr, "time=%s level=WARN msg=\"日志文件输出初始化失败,降级为仅标准输出\" path=%s err=%v\n", + time.Now().Format(time.RFC3339Nano), + filePath, + fileErr, + ) + } else { + cores = append(cores, fileCore) + } + } + + if len(cores) == 0 { + cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stdout), atomic)) + } + + core := zapcore.NewTee(cores...) + if options.Sampling.Enabled { + core = zapcore.NewSamplerWithOptions(core, samplingTick(), options.Sampling.Initial, options.Sampling.Thereafter) + } + core = sinkCore.Wrap(core) + + stacktraceLevel, _ := parseStacktraceLevel(options.StacktraceLevel) + zapOpts := make([]zap.Option, 0, 5) + if options.Caller { + zapOpts = append(zapOpts, zap.AddCaller()) + } + if stacktraceLevel <= zapcore.FatalLevel { + zapOpts = append(zapOpts, zap.AddStacktrace(stacktraceLevel)) + } + + logger := zap.New(core, zapOpts...).With( + zap.String("service", options.ServiceName), + zap.String("env", options.Environment), + ) + return logger, atomic, nil +} + +func buildFileCore(enc zapcore.Encoder, atomic zap.AtomicLevel, options InitOptions) (zapcore.Core, string, error) { + filePath := options.Output.FilePath + if strings.TrimSpace(filePath) == "" { + filePath = resolveLogFilePath("") + } + + dir := filepath.Dir(filePath) + if err := os.MkdirAll(dir, 0o755); err != nil { + return nil, filePath, err + } + lj := &lumberjack.Logger{ + Filename: filePath, + MaxSize: options.Rotation.MaxSizeMB, + MaxBackups: options.Rotation.MaxBackups, + MaxAge: options.Rotation.MaxAgeDays, + Compress: options.Rotation.Compress, + LocalTime: options.Rotation.LocalTime, + } + return zapcore.NewCore(enc, zapcore.AddSync(lj), atomic), filePath, nil +} + +type sinkCore struct { + core zapcore.Core + fields []zapcore.Field +} + +func newSinkCore() *sinkCore { + return &sinkCore{} +} + +func (s *sinkCore) Wrap(core zapcore.Core) 
zapcore.Core { + cp := *s + cp.core = core + return &cp +} + +func (s *sinkCore) Enabled(level zapcore.Level) bool { + return s.core.Enabled(level) +} + +func (s *sinkCore) With(fields []zapcore.Field) zapcore.Core { + nextFields := append([]zapcore.Field{}, s.fields...) + nextFields = append(nextFields, fields...) + return &sinkCore{ + core: s.core.With(fields), + fields: nextFields, + } +} + +func (s *sinkCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry { + // Delegate to inner core (tee) so each sub-core's level enabler is respected. + // Then add ourselves for sink forwarding only. + ce = s.core.Check(entry, ce) + if ce != nil { + ce = ce.AddCore(entry, s) + } + return ce +} + +func (s *sinkCore) Write(entry zapcore.Entry, fields []zapcore.Field) error { + // Only handle sink forwarding — the inner cores write via their own + // Write methods (added to CheckedEntry by s.core.Check above). + sink := loadSink() + if sink == nil { + return nil + } + + enc := zapcore.NewMapObjectEncoder() + for _, f := range s.fields { + f.AddTo(enc) + } + for _, f := range fields { + f.AddTo(enc) + } + + event := &LogEvent{ + Time: entry.Time, + Level: strings.ToLower(entry.Level.String()), + Component: entry.LoggerName, + Message: entry.Message, + LoggerName: entry.LoggerName, + Fields: enc.Fields, + } + sink.WriteLogEvent(event) + return nil +} + +func (s *sinkCore) Sync() error { + return s.core.Sync() +} + +type stdLogBridge struct { + logger *zap.Logger +} + +func newStdLogBridge(l *zap.Logger) io.Writer { + if l == nil { + l = zap.NewNop() + } + return &stdLogBridge{logger: l} +} + +func (b *stdLogBridge) Write(p []byte) (int, error) { + msg := normalizeStdLogMessage(string(p)) + if msg == "" { + return len(p), nil + } + + level := inferStdLogLevel(msg) + entry := b.logger.WithOptions(zap.AddCallerSkip(4)) + + switch level { + case LevelDebug: + entry.Debug(msg, zap.Bool("legacy_stdlog", true)) + case LevelWarn: + entry.Warn(msg, 
zap.Bool("legacy_stdlog", true)) + case LevelError, LevelFatal: + entry.Error(msg, zap.Bool("legacy_stdlog", true)) + default: + entry.Info(msg, zap.Bool("legacy_stdlog", true)) + } + return len(p), nil +} + +func normalizeStdLogMessage(raw string) string { + msg := strings.TrimSpace(strings.ReplaceAll(raw, "\n", " ")) + if msg == "" { + return "" + } + return strings.Join(strings.Fields(msg), " ") +} + +func inferStdLogLevel(msg string) Level { + lower := strings.ToLower(strings.TrimSpace(msg)) + if lower == "" { + return LevelInfo + } + + if strings.HasPrefix(lower, "[debug]") || strings.HasPrefix(lower, "debug:") { + return LevelDebug + } + if strings.HasPrefix(lower, "[warn]") || strings.HasPrefix(lower, "[warning]") || strings.HasPrefix(lower, "warn:") || strings.HasPrefix(lower, "warning:") { + return LevelWarn + } + if strings.HasPrefix(lower, "[error]") || strings.HasPrefix(lower, "error:") || strings.HasPrefix(lower, "fatal:") || strings.HasPrefix(lower, "panic:") { + return LevelError + } + + if strings.Contains(lower, " failed") || strings.Contains(lower, "error") || strings.Contains(lower, "panic") || strings.Contains(lower, "fatal") { + return LevelError + } + if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") { + return LevelWarn + } + return LevelInfo +} + +// LegacyPrintf 用于平滑迁移历史的 printf 风格日志到结构化 logger。 +func LegacyPrintf(component, format string, args ...any) { + msg := normalizeStdLogMessage(fmt.Sprintf(format, args...)) + if msg == "" { + return + } + + initialized := global.Load() != nil + if !initialized { + // 在日志系统未初始化前,回退到标准库 log,避免测试/工具链丢日志。 + log.Print(msg) + return + } + + l := L() + if component != "" { + l = l.With(zap.String("component", component)) + } + l = l.WithOptions(zap.AddCallerSkip(1)) + + switch inferStdLogLevel(msg) { + case LevelDebug: + l.Debug(msg, zap.Bool("legacy_printf", true)) + case LevelWarn: + l.Warn(msg, 
zap.Bool("legacy_printf", true)) + case LevelError, LevelFatal: + l.Error(msg, zap.Bool("legacy_printf", true)) + default: + l.Info(msg, zap.Bool("legacy_printf", true)) + } +} + +type contextKey string + +const loggerContextKey contextKey = "ctx_logger" + +func IntoContext(ctx context.Context, l *zap.Logger) context.Context { + if ctx == nil { + ctx = context.Background() + } + if l == nil { + l = L() + } + return context.WithValue(ctx, loggerContextKey, l) +} + +func FromContext(ctx context.Context) *zap.Logger { + if ctx == nil { + return L() + } + if l, ok := ctx.Value(loggerContextKey).(*zap.Logger); ok && l != nil { + return l + } + return L() +} diff --git a/backend/internal/pkg/logger/logger_test.go b/backend/internal/pkg/logger/logger_test.go new file mode 100644 index 00000000..74aae061 --- /dev/null +++ b/backend/internal/pkg/logger/logger_test.go @@ -0,0 +1,192 @@ +package logger + +import ( + "encoding/json" + "io" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestInit_DualOutput(t *testing.T) { + tmpDir := t.TempDir() + logPath := filepath.Join(tmpDir, "logs", "sub2api.log") + + origStdout := os.Stdout + origStderr := os.Stderr + stdoutR, stdoutW, err := os.Pipe() + if err != nil { + t.Fatalf("create stdout pipe: %v", err) + } + stderrR, stderrW, err := os.Pipe() + if err != nil { + t.Fatalf("create stderr pipe: %v", err) + } + os.Stdout = stdoutW + os.Stderr = stderrW + t.Cleanup(func() { + os.Stdout = origStdout + os.Stderr = origStderr + _ = stdoutR.Close() + _ = stderrR.Close() + _ = stdoutW.Close() + _ = stderrW.Close() + }) + + err = Init(InitOptions{ + Level: "debug", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: OutputOptions{ + ToStdout: true, + ToFile: true, + FilePath: logPath, + }, + Rotation: RotationOptions{ + MaxSizeMB: 10, + MaxBackups: 2, + MaxAgeDays: 1, + }, + Sampling: SamplingOptions{Enabled: false}, + }) + if err != nil { + t.Fatalf("Init() error: %v", err) + } + + 
L().Info("dual-output-info") + L().Warn("dual-output-warn") + Sync() + + _ = stdoutW.Close() + _ = stderrW.Close() + stdoutBytes, _ := io.ReadAll(stdoutR) + stderrBytes, _ := io.ReadAll(stderrR) + stdoutText := string(stdoutBytes) + stderrText := string(stderrBytes) + + if !strings.Contains(stdoutText, "dual-output-info") { + t.Fatalf("stdout missing info log: %s", stdoutText) + } + if !strings.Contains(stderrText, "dual-output-warn") { + t.Fatalf("stderr missing warn log: %s", stderrText) + } + + fileBytes, err := os.ReadFile(logPath) + if err != nil { + t.Fatalf("read log file: %v", err) + } + fileText := string(fileBytes) + if !strings.Contains(fileText, "dual-output-info") || !strings.Contains(fileText, "dual-output-warn") { + t.Fatalf("file missing logs: %s", fileText) + } +} + +func TestInit_FileOutputFailureDowngrade(t *testing.T) { + origStdout := os.Stdout + origStderr := os.Stderr + _, stdoutW, err := os.Pipe() + if err != nil { + t.Fatalf("create stdout pipe: %v", err) + } + stderrR, stderrW, err := os.Pipe() + if err != nil { + t.Fatalf("create stderr pipe: %v", err) + } + os.Stdout = stdoutW + os.Stderr = stderrW + t.Cleanup(func() { + os.Stdout = origStdout + os.Stderr = origStderr + _ = stdoutW.Close() + _ = stderrR.Close() + _ = stderrW.Close() + }) + + err = Init(InitOptions{ + Level: "info", + Format: "json", + Output: OutputOptions{ + ToStdout: true, + ToFile: true, + FilePath: filepath.Join(os.DevNull, "logs", "sub2api.log"), + }, + Rotation: RotationOptions{ + MaxSizeMB: 10, + MaxBackups: 1, + MaxAgeDays: 1, + }, + }) + if err != nil { + t.Fatalf("Init() should downgrade instead of failing, got: %v", err) + } + + _ = stderrW.Close() + stderrBytes, _ := io.ReadAll(stderrR) + if !strings.Contains(string(stderrBytes), "日志文件输出初始化失败") { + t.Fatalf("stderr should contain fallback warning, got: %s", string(stderrBytes)) + } +} + +func TestInit_CallerShouldPointToCallsite(t *testing.T) { + origStdout := os.Stdout + origStderr := os.Stderr + stdoutR, 
stdoutW, err := os.Pipe() + if err != nil { + t.Fatalf("create stdout pipe: %v", err) + } + _, stderrW, err := os.Pipe() + if err != nil { + t.Fatalf("create stderr pipe: %v", err) + } + os.Stdout = stdoutW + os.Stderr = stderrW + t.Cleanup(func() { + os.Stdout = origStdout + os.Stderr = origStderr + _ = stdoutR.Close() + _ = stdoutW.Close() + _ = stderrW.Close() + }) + + if err := Init(InitOptions{ + Level: "info", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Caller: true, + Output: OutputOptions{ + ToStdout: true, + ToFile: false, + }, + Sampling: SamplingOptions{Enabled: false}, + }); err != nil { + t.Fatalf("Init() error: %v", err) + } + + L().Info("caller-check") + Sync() + _ = stdoutW.Close() + logBytes, _ := io.ReadAll(stdoutR) + + var line string + for _, item := range strings.Split(string(logBytes), "\n") { + if strings.Contains(item, "caller-check") { + line = item + break + } + } + if line == "" { + t.Fatalf("log output missing caller-check: %s", string(logBytes)) + } + + var payload map[string]any + if err := json.Unmarshal([]byte(line), &payload); err != nil { + t.Fatalf("parse log json failed: %v, line=%s", err, line) + } + caller, _ := payload["caller"].(string) + if !strings.Contains(caller, "logger_test.go:") { + t.Fatalf("caller should point to this test file, got: %s", caller) + } +} diff --git a/backend/internal/pkg/logger/options.go b/backend/internal/pkg/logger/options.go new file mode 100644 index 00000000..efcd701c --- /dev/null +++ b/backend/internal/pkg/logger/options.go @@ -0,0 +1,161 @@ +package logger + +import ( + "os" + "path/filepath" + "strings" + "time" +) + +const ( + // DefaultContainerLogPath 为容器内默认日志文件路径。 + DefaultContainerLogPath = "/app/data/logs/sub2api.log" + defaultLogFilename = "sub2api.log" +) + +type InitOptions struct { + Level string + Format string + ServiceName string + Environment string + Caller bool + StacktraceLevel string + Output OutputOptions + Rotation RotationOptions + Sampling 
SamplingOptions +} + +type OutputOptions struct { + ToStdout bool + ToFile bool + FilePath string +} + +type RotationOptions struct { + MaxSizeMB int + MaxBackups int + MaxAgeDays int + Compress bool + LocalTime bool +} + +type SamplingOptions struct { + Enabled bool + Initial int + Thereafter int +} + +func (o InitOptions) normalized() InitOptions { + out := o + out.Level = strings.ToLower(strings.TrimSpace(out.Level)) + if out.Level == "" { + out.Level = "info" + } + out.Format = strings.ToLower(strings.TrimSpace(out.Format)) + if out.Format == "" { + out.Format = "console" + } + out.ServiceName = strings.TrimSpace(out.ServiceName) + if out.ServiceName == "" { + out.ServiceName = "sub2api" + } + out.Environment = strings.TrimSpace(out.Environment) + if out.Environment == "" { + out.Environment = "production" + } + out.StacktraceLevel = strings.ToLower(strings.TrimSpace(out.StacktraceLevel)) + if out.StacktraceLevel == "" { + out.StacktraceLevel = "error" + } + if !out.Output.ToStdout && !out.Output.ToFile { + out.Output.ToStdout = true + } + out.Output.FilePath = resolveLogFilePath(out.Output.FilePath) + if out.Rotation.MaxSizeMB <= 0 { + out.Rotation.MaxSizeMB = 100 + } + if out.Rotation.MaxBackups < 0 { + out.Rotation.MaxBackups = 10 + } + if out.Rotation.MaxAgeDays < 0 { + out.Rotation.MaxAgeDays = 7 + } + if out.Sampling.Enabled { + if out.Sampling.Initial <= 0 { + out.Sampling.Initial = 100 + } + if out.Sampling.Thereafter <= 0 { + out.Sampling.Thereafter = 100 + } + } + return out +} + +func resolveLogFilePath(explicit string) string { + explicit = strings.TrimSpace(explicit) + if explicit != "" { + return explicit + } + dataDir := strings.TrimSpace(os.Getenv("DATA_DIR")) + if dataDir != "" { + return filepath.Join(dataDir, "logs", defaultLogFilename) + } + return DefaultContainerLogPath +} + +func bootstrapOptions() InitOptions { + return InitOptions{ + Level: "info", + Format: "console", + ServiceName: "sub2api", + Environment: "bootstrap", + Output: 
OutputOptions{ + ToStdout: true, + ToFile: false, + }, + Rotation: RotationOptions{ + MaxSizeMB: 100, + MaxBackups: 10, + MaxAgeDays: 7, + Compress: true, + LocalTime: true, + }, + Sampling: SamplingOptions{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + } +} + +func parseLevel(level string) (Level, bool) { + switch strings.ToLower(strings.TrimSpace(level)) { + case "debug": + return LevelDebug, true + case "info": + return LevelInfo, true + case "warn": + return LevelWarn, true + case "error": + return LevelError, true + default: + return LevelInfo, false + } +} + +func parseStacktraceLevel(level string) (Level, bool) { + switch strings.ToLower(strings.TrimSpace(level)) { + case "none": + return LevelFatal + 1, true + case "error": + return LevelError, true + case "fatal": + return LevelFatal, true + default: + return LevelError, false + } +} + +func samplingTick() time.Duration { + return time.Second +} diff --git a/backend/internal/pkg/logger/options_test.go b/backend/internal/pkg/logger/options_test.go new file mode 100644 index 00000000..10d50d72 --- /dev/null +++ b/backend/internal/pkg/logger/options_test.go @@ -0,0 +1,102 @@ +package logger + +import ( + "os" + "path/filepath" + "testing" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +func TestResolveLogFilePath_Default(t *testing.T) { + t.Setenv("DATA_DIR", "") + got := resolveLogFilePath("") + if got != DefaultContainerLogPath { + t.Fatalf("resolveLogFilePath() = %q, want %q", got, DefaultContainerLogPath) + } +} + +func TestResolveLogFilePath_WithDataDir(t *testing.T) { + t.Setenv("DATA_DIR", "/tmp/sub2api-data") + got := resolveLogFilePath("") + want := filepath.Join("/tmp/sub2api-data", "logs", "sub2api.log") + if got != want { + t.Fatalf("resolveLogFilePath() = %q, want %q", got, want) + } +} + +func TestResolveLogFilePath_ExplicitPath(t *testing.T) { + t.Setenv("DATA_DIR", "/tmp/ignore") + got := resolveLogFilePath("/var/log/custom.log") + if got != "/var/log/custom.log" { + 
t.Fatalf("resolveLogFilePath() = %q, want explicit path", got) + } +} + +func TestNormalizedOptions_InvalidFallback(t *testing.T) { + t.Setenv("DATA_DIR", "") + opts := InitOptions{ + Level: "TRACE", + Format: "TEXT", + ServiceName: "", + Environment: "", + StacktraceLevel: "panic", + Output: OutputOptions{ + ToStdout: false, + ToFile: false, + }, + Rotation: RotationOptions{ + MaxSizeMB: 0, + MaxBackups: -1, + MaxAgeDays: -1, + }, + Sampling: SamplingOptions{ + Enabled: true, + Initial: 0, + Thereafter: 0, + }, + } + out := opts.normalized() + if out.Level != "trace" { + // normalized 仅做 trim/lower,不做校验;校验在 config 层。 + t.Fatalf("normalized level should preserve value for upstream validation, got %q", out.Level) + } + if !out.Output.ToStdout { + t.Fatalf("normalized output should fallback to stdout") + } + if out.Output.FilePath != DefaultContainerLogPath { + t.Fatalf("normalized file path = %q", out.Output.FilePath) + } + if out.Rotation.MaxSizeMB != 100 { + t.Fatalf("normalized max_size_mb = %d", out.Rotation.MaxSizeMB) + } + if out.Rotation.MaxBackups != 10 { + t.Fatalf("normalized max_backups = %d", out.Rotation.MaxBackups) + } + if out.Rotation.MaxAgeDays != 7 { + t.Fatalf("normalized max_age_days = %d", out.Rotation.MaxAgeDays) + } + if out.Sampling.Initial != 100 || out.Sampling.Thereafter != 100 { + t.Fatalf("normalized sampling defaults invalid: %+v", out.Sampling) + } +} + +func TestBuildFileCore_InvalidPathFallback(t *testing.T) { + t.Setenv("DATA_DIR", "") + opts := bootstrapOptions() + opts.Output.ToFile = true + opts.Output.FilePath = filepath.Join(os.DevNull, "logs", "sub2api.log") + encoderCfg := zapcore.EncoderConfig{ + TimeKey: "time", + LevelKey: "level", + MessageKey: "msg", + EncodeTime: zapcore.ISO8601TimeEncoder, + EncodeLevel: zapcore.CapitalLevelEncoder, + } + encoder := zapcore.NewJSONEncoder(encoderCfg) + _, _, err := buildFileCore(encoder, zap.NewAtomicLevel(), opts) + if err == nil { + t.Fatalf("buildFileCore() expected error for 
invalid path") + } +} diff --git a/backend/internal/pkg/logger/slog_handler.go b/backend/internal/pkg/logger/slog_handler.go new file mode 100644 index 00000000..602ca1e0 --- /dev/null +++ b/backend/internal/pkg/logger/slog_handler.go @@ -0,0 +1,131 @@ +package logger + +import ( + "context" + "log/slog" + "strings" + "time" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type slogZapHandler struct { + logger *zap.Logger + attrs []slog.Attr + groups []string +} + +func newSlogZapHandler(logger *zap.Logger) slog.Handler { + if logger == nil { + logger = zap.NewNop() + } + return &slogZapHandler{ + logger: logger, + attrs: make([]slog.Attr, 0, 8), + groups: make([]string, 0, 4), + } +} + +func (h *slogZapHandler) Enabled(_ context.Context, level slog.Level) bool { + switch { + case level >= slog.LevelError: + return h.logger.Core().Enabled(LevelError) + case level >= slog.LevelWarn: + return h.logger.Core().Enabled(LevelWarn) + case level <= slog.LevelDebug: + return h.logger.Core().Enabled(LevelDebug) + default: + return h.logger.Core().Enabled(LevelInfo) + } +} + +func (h *slogZapHandler) Handle(_ context.Context, record slog.Record) error { + fields := make([]zap.Field, 0, len(h.attrs)+record.NumAttrs()+3) + fields = append(fields, slogAttrsToZapFields(h.groups, h.attrs)...) + record.Attrs(func(attr slog.Attr) bool { + fields = append(fields, slogAttrToZapField(h.groups, attr)) + return true + }) + + switch { + case record.Level >= slog.LevelError: + h.logger.Error(record.Message, fields...) + case record.Level >= slog.LevelWarn: + h.logger.Warn(record.Message, fields...) + case record.Level <= slog.LevelDebug: + h.logger.Debug(record.Message, fields...) + default: + h.logger.Info(record.Message, fields...) + } + return nil +} + +func (h *slogZapHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + next := *h + next.attrs = append(append([]slog.Attr{}, h.attrs...), attrs...) 
+ return &next +} + +func (h *slogZapHandler) WithGroup(name string) slog.Handler { + name = strings.TrimSpace(name) + if name == "" { + return h + } + next := *h + next.groups = append(append([]string{}, h.groups...), name) + return &next +} + +func slogAttrsToZapFields(groups []string, attrs []slog.Attr) []zap.Field { + fields := make([]zap.Field, 0, len(attrs)) + for _, attr := range attrs { + fields = append(fields, slogAttrToZapField(groups, attr)) + } + return fields +} + +func slogAttrToZapField(groups []string, attr slog.Attr) zap.Field { + if len(groups) > 0 { + attr.Key = strings.Join(append(append([]string{}, groups...), attr.Key), ".") + } + value := attr.Value.Resolve() + switch value.Kind() { + case slog.KindBool: + return zap.Bool(attr.Key, value.Bool()) + case slog.KindInt64: + return zap.Int64(attr.Key, value.Int64()) + case slog.KindUint64: + return zap.Uint64(attr.Key, value.Uint64()) + case slog.KindFloat64: + return zap.Float64(attr.Key, value.Float64()) + case slog.KindDuration: + return zap.Duration(attr.Key, value.Duration()) + case slog.KindTime: + return zap.Time(attr.Key, value.Time()) + case slog.KindString: + return zap.String(attr.Key, value.String()) + case slog.KindGroup: + groupFields := make([]zap.Field, 0, len(value.Group())) + for _, nested := range value.Group() { + groupFields = append(groupFields, slogAttrToZapField(nil, nested)) + } + return zap.Object(attr.Key, zapObjectFields(groupFields)) + case slog.KindAny: + if t, ok := value.Any().(time.Time); ok { + return zap.Time(attr.Key, t) + } + return zap.Any(attr.Key, value.Any()) + default: + return zap.String(attr.Key, value.String()) + } +} + +type zapObjectFields []zap.Field + +func (z zapObjectFields) MarshalLogObject(enc zapcore.ObjectEncoder) error { + for _, field := range z { + field.AddTo(enc) + } + return nil +} diff --git a/backend/internal/pkg/logger/slog_handler_test.go b/backend/internal/pkg/logger/slog_handler_test.go new file mode 100644 index 
00000000..d2b4208d --- /dev/null +++ b/backend/internal/pkg/logger/slog_handler_test.go @@ -0,0 +1,88 @@ +package logger + +import ( + "context" + "log/slog" + "testing" + "time" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type captureState struct { + writes []capturedWrite +} + +type capturedWrite struct { + fields []zapcore.Field +} + +type captureCore struct { + state *captureState + withFields []zapcore.Field +} + +func newCaptureCore() *captureCore { + return &captureCore{state: &captureState{}} +} + +func (c *captureCore) Enabled(zapcore.Level) bool { + return true +} + +func (c *captureCore) With(fields []zapcore.Field) zapcore.Core { + nextFields := make([]zapcore.Field, 0, len(c.withFields)+len(fields)) + nextFields = append(nextFields, c.withFields...) + nextFields = append(nextFields, fields...) + return &captureCore{ + state: c.state, + withFields: nextFields, + } +} + +func (c *captureCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry { + return ce.AddCore(entry, c) +} + +func (c *captureCore) Write(entry zapcore.Entry, fields []zapcore.Field) error { + allFields := make([]zapcore.Field, 0, len(c.withFields)+len(fields)) + allFields = append(allFields, c.withFields...) + allFields = append(allFields, fields...) 
+ c.state.writes = append(c.state.writes, capturedWrite{ + fields: allFields, + }) + return nil +} + +func (c *captureCore) Sync() error { + return nil +} + +func TestSlogZapHandler_Handle_DoesNotAppendTimeField(t *testing.T) { + core := newCaptureCore() + handler := newSlogZapHandler(zap.New(core)) + + record := slog.NewRecord(time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC), slog.LevelInfo, "hello", 0) + record.AddAttrs(slog.String("component", "http.access")) + + if err := handler.Handle(context.Background(), record); err != nil { + t.Fatalf("handle slog record: %v", err) + } + if len(core.state.writes) != 1 { + t.Fatalf("write calls = %d, want 1", len(core.state.writes)) + } + + var hasComponent bool + for _, field := range core.state.writes[0].fields { + if field.Key == "time" { + t.Fatalf("unexpected duplicate time field in slog adapter output") + } + if field.Key == "component" { + hasComponent = true + } + } + if !hasComponent { + t.Fatalf("component field should be preserved") + } +} diff --git a/backend/internal/pkg/logger/stdlog_bridge_test.go b/backend/internal/pkg/logger/stdlog_bridge_test.go new file mode 100644 index 00000000..4482a2ec --- /dev/null +++ b/backend/internal/pkg/logger/stdlog_bridge_test.go @@ -0,0 +1,166 @@ +package logger + +import ( + "io" + "log" + "os" + "strings" + "testing" +) + +func TestInferStdLogLevel(t *testing.T) { + cases := []struct { + msg string + want Level + }{ + {msg: "Warning: queue full", want: LevelWarn}, + {msg: "Forward request failed: timeout", want: LevelError}, + {msg: "[ERROR] upstream unavailable", want: LevelError}, + {msg: "[OpenAI WS Mode] reconnect_retry account_id=22 retry=1 max_retries=5", want: LevelInfo}, + {msg: "service started", want: LevelInfo}, + {msg: "debug: cache miss", want: LevelDebug}, + } + + for _, tc := range cases { + got := inferStdLogLevel(tc.msg) + if got != tc.want { + t.Fatalf("inferStdLogLevel(%q)=%v want=%v", tc.msg, got, tc.want) + } + } +} + +func TestNormalizeStdLogMessage(t 
*testing.T) { + raw := " [TokenRefresh] cycle complete \n total=1 failed=0 \n" + got := normalizeStdLogMessage(raw) + want := "[TokenRefresh] cycle complete total=1 failed=0" + if got != want { + t.Fatalf("normalizeStdLogMessage()=%q want=%q", got, want) + } +} + +func TestStdLogBridgeRoutesLevels(t *testing.T) { + origStdout := os.Stdout + origStderr := os.Stderr + stdoutR, stdoutW, err := os.Pipe() + if err != nil { + t.Fatalf("create stdout pipe: %v", err) + } + stderrR, stderrW, err := os.Pipe() + if err != nil { + t.Fatalf("create stderr pipe: %v", err) + } + os.Stdout = stdoutW + os.Stderr = stderrW + t.Cleanup(func() { + os.Stdout = origStdout + os.Stderr = origStderr + _ = stdoutR.Close() + _ = stdoutW.Close() + _ = stderrR.Close() + _ = stderrW.Close() + }) + + if err := Init(InitOptions{ + Level: "debug", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: OutputOptions{ + ToStdout: true, + ToFile: false, + }, + Sampling: SamplingOptions{Enabled: false}, + }); err != nil { + t.Fatalf("Init() error: %v", err) + } + + log.Printf("service started") + log.Printf("Warning: queue full") + log.Printf("Forward request failed: timeout") + Sync() + + _ = stdoutW.Close() + _ = stderrW.Close() + stdoutBytes, _ := io.ReadAll(stdoutR) + stderrBytes, _ := io.ReadAll(stderrR) + stdoutText := string(stdoutBytes) + stderrText := string(stderrBytes) + + if !strings.Contains(stdoutText, "service started") { + t.Fatalf("stdout missing info log: %s", stdoutText) + } + if !strings.Contains(stderrText, "Warning: queue full") { + t.Fatalf("stderr missing warn log: %s", stderrText) + } + if !strings.Contains(stderrText, "Forward request failed: timeout") { + t.Fatalf("stderr missing error log: %s", stderrText) + } + if !strings.Contains(stderrText, "\"legacy_stdlog\":true") { + t.Fatalf("stderr missing legacy_stdlog marker: %s", stderrText) + } +} + +func TestLegacyPrintfRoutesLevels(t *testing.T) { + origStdout := os.Stdout + origStderr := os.Stderr + 
stdoutR, stdoutW, err := os.Pipe() + if err != nil { + t.Fatalf("create stdout pipe: %v", err) + } + stderrR, stderrW, err := os.Pipe() + if err != nil { + t.Fatalf("create stderr pipe: %v", err) + } + os.Stdout = stdoutW + os.Stderr = stderrW + t.Cleanup(func() { + os.Stdout = origStdout + os.Stderr = origStderr + _ = stdoutR.Close() + _ = stdoutW.Close() + _ = stderrR.Close() + _ = stderrW.Close() + }) + + if err := Init(InitOptions{ + Level: "debug", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: OutputOptions{ + ToStdout: true, + ToFile: false, + }, + Sampling: SamplingOptions{Enabled: false}, + }); err != nil { + t.Fatalf("Init() error: %v", err) + } + + LegacyPrintf("service.test", "request started") + LegacyPrintf("service.test", "Warning: queue full") + LegacyPrintf("service.test", "forward failed: timeout") + Sync() + + _ = stdoutW.Close() + _ = stderrW.Close() + stdoutBytes, _ := io.ReadAll(stdoutR) + stderrBytes, _ := io.ReadAll(stderrR) + stdoutText := string(stdoutBytes) + stderrText := string(stderrBytes) + + if !strings.Contains(stdoutText, "request started") { + t.Fatalf("stdout missing info log: %s", stdoutText) + } + if !strings.Contains(stderrText, "Warning: queue full") { + t.Fatalf("stderr missing warn log: %s", stderrText) + } + if !strings.Contains(stderrText, "forward failed: timeout") { + t.Fatalf("stderr missing error log: %s", stderrText) + } + if !strings.Contains(stderrText, "\"legacy_printf\":true") { + t.Fatalf("stderr missing legacy_printf marker: %s", stderrText) + } + if !strings.Contains(stderrText, "\"component\":\"service.test\"") { + t.Fatalf("stderr missing component field: %s", stderrText) + } +} diff --git a/backend/internal/pkg/oauth/oauth.go b/backend/internal/pkg/oauth/oauth.go index 33caffd7..cfc91bee 100644 --- a/backend/internal/pkg/oauth/oauth.go +++ b/backend/internal/pkg/oauth/oauth.go @@ -50,6 +50,7 @@ type OAuthSession struct { type SessionStore struct { mu sync.RWMutex sessions 
map[string]*OAuthSession + stopOnce sync.Once stopCh chan struct{} } @@ -65,7 +66,9 @@ func NewSessionStore() *SessionStore { // Stop stops the cleanup goroutine func (s *SessionStore) Stop() { - close(s.stopCh) + s.stopOnce.Do(func() { + close(s.stopCh) + }) } // Set stores a session diff --git a/backend/internal/pkg/oauth/oauth_test.go b/backend/internal/pkg/oauth/oauth_test.go new file mode 100644 index 00000000..9e59f0f0 --- /dev/null +++ b/backend/internal/pkg/oauth/oauth_test.go @@ -0,0 +1,43 @@ +package oauth + +import ( + "sync" + "testing" + "time" +) + +func TestSessionStore_Stop_Idempotent(t *testing.T) { + store := NewSessionStore() + + store.Stop() + store.Stop() + + select { + case <-store.stopCh: + // ok + case <-time.After(time.Second): + t.Fatal("stopCh 未关闭") + } +} + +func TestSessionStore_Stop_Concurrent(t *testing.T) { + store := NewSessionStore() + + var wg sync.WaitGroup + for range 50 { + wg.Add(1) + go func() { + defer wg.Done() + store.Stop() + }() + } + + wg.Wait() + + select { + case <-store.stopCh: + // ok + case <-time.After(time.Second): + t.Fatal("stopCh 未关闭") + } +} diff --git a/backend/internal/pkg/openai/constants.go b/backend/internal/pkg/openai/constants.go index fd24b11d..4bbc68e7 100644 --- a/backend/internal/pkg/openai/constants.go +++ b/backend/internal/pkg/openai/constants.go @@ -15,8 +15,8 @@ type Model struct { // DefaultModels OpenAI models list var DefaultModels = []Model{ - {ID: "gpt-5.3", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3"}, {ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"}, + {ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"}, {ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"}, {ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, 
OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"}, {ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"}, diff --git a/backend/internal/pkg/openai/oauth.go b/backend/internal/pkg/openai/oauth.go index df972a13..8bdcbe16 100644 --- a/backend/internal/pkg/openai/oauth.go +++ b/backend/internal/pkg/openai/oauth.go @@ -17,6 +17,8 @@ import ( const ( // OAuth Client ID for OpenAI (Codex CLI official) ClientID = "app_EMoamEEZ73f0CkXaXp7hrann" + // OAuth Client ID for Sora mobile flow (aligned with sora2api) + SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK" // OAuth endpoints AuthorizeURL = "https://auth.openai.com/oauth/authorize" @@ -34,10 +36,18 @@ const ( SessionTTL = 30 * time.Minute ) +const ( + // OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client. + OAuthPlatformOpenAI = "openai" + // OAuthPlatformSora uses Sora OAuth client. + OAuthPlatformSora = "sora" +) + // OAuthSession stores OAuth flow state for OpenAI type OAuthSession struct { State string `json:"state"` CodeVerifier string `json:"code_verifier"` + ClientID string `json:"client_id,omitempty"` ProxyURL string `json:"proxy_url,omitempty"` RedirectURI string `json:"redirect_uri"` CreatedAt time.Time `json:"created_at"` @@ -47,6 +57,7 @@ type OAuthSession struct { type SessionStore struct { mu sync.RWMutex sessions map[string]*OAuthSession + stopOnce sync.Once stopCh chan struct{} } @@ -92,7 +103,9 @@ func (s *SessionStore) Delete(sessionID string) { // Stop stops the cleanup goroutine func (s *SessionStore) Stop() { - close(s.stopCh) + s.stopOnce.Do(func() { + close(s.stopCh) + }) } // cleanup removes expired sessions periodically @@ -169,13 +182,20 @@ func base64URLEncode(data []byte) string { // BuildAuthorizationURL builds the OpenAI OAuth authorization URL func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string { + return BuildAuthorizationURLForPlatform(state, codeChallenge, 
redirectURI, OAuthPlatformOpenAI) +} + +// BuildAuthorizationURLForPlatform builds authorization URL by platform. +func BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, platform string) string { if redirectURI == "" { redirectURI = DefaultRedirectURI } + clientID, codexFlow := OAuthClientConfigByPlatform(platform) + params := url.Values{} params.Set("response_type", "code") - params.Set("client_id", ClientID) + params.Set("client_id", clientID) params.Set("redirect_uri", redirectURI) params.Set("scope", DefaultScopes) params.Set("state", state) @@ -183,11 +203,25 @@ func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string { params.Set("code_challenge_method", "S256") // OpenAI specific parameters params.Set("id_token_add_organizations", "true") - params.Set("codex_cli_simplified_flow", "true") + if codexFlow { + params.Set("codex_cli_simplified_flow", "true") + } return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode()) } +// OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled. +// Sora 授权流程复用 Codex CLI 的 client_id(支持 localhost redirect_uri), +// 但不启用 codex_cli_simplified_flow;拿到的 access_token 绑定同一 OpenAI 账号,对 Sora API 同样可用。 +func OAuthClientConfigByPlatform(platform string) (clientID string, codexFlow bool) { + switch strings.ToLower(strings.TrimSpace(platform)) { + case OAuthPlatformSora: + return ClientID, false + default: + return ClientID, true + } +} + // TokenRequest represents the token exchange request body type TokenRequest struct { GrantType string `json:"grant_type"` @@ -291,9 +325,11 @@ func (r *RefreshTokenRequest) ToFormData() string { return params.Encode() } -// ParseIDToken parses the ID Token JWT and extracts claims -// Note: This does NOT verify the signature - it only decodes the payload -// For production, you should verify the token signature using OpenAI's public keys +// ParseIDToken parses the ID Token JWT and extracts claims. 
+// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。 +// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名: +// +// https://auth.openai.com/.well-known/jwks.json func ParseIDToken(idToken string) (*IDTokenClaims, error) { parts := strings.Split(idToken, ".") if len(parts) != 3 { @@ -324,6 +360,13 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) { return nil, fmt.Errorf("failed to parse JWT claims: %w", err) } + // 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌) + const clockSkewTolerance = 120 // 秒 + now := time.Now().Unix() + if claims.Exp > 0 && now > claims.Exp+clockSkewTolerance { + return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance) + } + return &claims, nil } diff --git a/backend/internal/pkg/openai/oauth_test.go b/backend/internal/pkg/openai/oauth_test.go new file mode 100644 index 00000000..2970addf --- /dev/null +++ b/backend/internal/pkg/openai/oauth_test.go @@ -0,0 +1,82 @@ +package openai + +import ( + "net/url" + "sync" + "testing" + "time" +) + +func TestSessionStore_Stop_Idempotent(t *testing.T) { + store := NewSessionStore() + + store.Stop() + store.Stop() + + select { + case <-store.stopCh: + // ok + case <-time.After(time.Second): + t.Fatal("stopCh 未关闭") + } +} + +func TestSessionStore_Stop_Concurrent(t *testing.T) { + store := NewSessionStore() + + var wg sync.WaitGroup + for range 50 { + wg.Add(1) + go func() { + defer wg.Done() + store.Stop() + }() + } + + wg.Wait() + + select { + case <-store.stopCh: + // ok + case <-time.After(time.Second): + t.Fatal("stopCh 未关闭") + } +} + +func TestBuildAuthorizationURLForPlatform_OpenAI(t *testing.T) { + authURL := BuildAuthorizationURLForPlatform("state-1", "challenge-1", DefaultRedirectURI, OAuthPlatformOpenAI) + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("Parse URL failed: %v", err) + } + q := parsed.Query() + if got := q.Get("client_id"); got != ClientID { + t.Fatalf("client_id mismatch: got=%q want=%q", got, 
ClientID) + } + if got := q.Get("codex_cli_simplified_flow"); got != "true" { + t.Fatalf("codex flow mismatch: got=%q want=true", got) + } + if got := q.Get("id_token_add_organizations"); got != "true" { + t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got) + } +} + +// TestBuildAuthorizationURLForPlatform_Sora 验证 Sora 平台复用 Codex CLI 的 client_id, +// 但不启用 codex_cli_simplified_flow。 +func TestBuildAuthorizationURLForPlatform_Sora(t *testing.T) { + authURL := BuildAuthorizationURLForPlatform("state-2", "challenge-2", DefaultRedirectURI, OAuthPlatformSora) + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("Parse URL failed: %v", err) + } + q := parsed.Query() + if got := q.Get("client_id"); got != ClientID { + t.Fatalf("client_id mismatch: got=%q want=%q (Sora should reuse Codex CLI client_id)", got, ClientID) + } + if got := q.Get("codex_cli_simplified_flow"); got != "" { + t.Fatalf("codex flow should be empty for sora, got=%q", got) + } + if got := q.Get("id_token_add_organizations"); got != "true" { + t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got) + } +} diff --git a/backend/internal/pkg/openai/request.go b/backend/internal/pkg/openai/request.go index 5b049ddc..c24d1273 100644 --- a/backend/internal/pkg/openai/request.go +++ b/backend/internal/pkg/openai/request.go @@ -1,5 +1,7 @@ package openai +import "strings" + // CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns // Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2" var CodexCLIUserAgentPrefixes = []string{ @@ -7,10 +9,67 @@ var CodexCLIUserAgentPrefixes = []string{ "codex_cli_rs/", } +// CodexOfficialClientUserAgentPrefixes matches Codex 官方客户端家族 User-Agent 前缀。 +// 该列表仅用于 OpenAI OAuth `codex_cli_only` 访问限制判定。 +var CodexOfficialClientUserAgentPrefixes = []string{ + "codex_cli_rs/", + "codex_vscode/", + "codex_app/", + "codex_chatgpt_desktop/", + "codex_atlas/", + "codex_exec/", + "codex_sdk_ts/", + "codex ", +} + +// 
CodexOfficialClientOriginatorPrefixes matches Codex 官方客户端家族 originator 前缀。 +// 说明:OpenAI 官方 Codex 客户端并不只使用固定的 codex_app 标识。 +// 例如 codex_cli_rs、codex_vscode、codex_chatgpt_desktop、codex_atlas、codex_exec、codex_sdk_ts 等。 +var CodexOfficialClientOriginatorPrefixes = []string{ + "codex_", + "codex ", +} + // IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request func IsCodexCLIRequest(userAgent string) bool { - for _, prefix := range CodexCLIUserAgentPrefixes { - if len(userAgent) >= len(prefix) && userAgent[:len(prefix)] == prefix { + ua := normalizeCodexClientHeader(userAgent) + if ua == "" { + return false + } + return matchCodexClientHeaderPrefixes(ua, CodexCLIUserAgentPrefixes) +} + +// IsCodexOfficialClientRequest checks if the User-Agent indicates a Codex 官方客户端请求。 +// 与 IsCodexCLIRequest 解耦,避免影响历史兼容逻辑。 +func IsCodexOfficialClientRequest(userAgent string) bool { + ua := normalizeCodexClientHeader(userAgent) + if ua == "" { + return false + } + return matchCodexClientHeaderPrefixes(ua, CodexOfficialClientUserAgentPrefixes) +} + +// IsCodexOfficialClientOriginator checks if originator indicates a Codex 官方客户端请求。 +func IsCodexOfficialClientOriginator(originator string) bool { + v := normalizeCodexClientHeader(originator) + if v == "" { + return false + } + return matchCodexClientHeaderPrefixes(v, CodexOfficialClientOriginatorPrefixes) +} + +func normalizeCodexClientHeader(value string) string { + return strings.ToLower(strings.TrimSpace(value)) +} + +func matchCodexClientHeaderPrefixes(value string, prefixes []string) bool { + for _, prefix := range prefixes { + normalizedPrefix := normalizeCodexClientHeader(prefix) + if normalizedPrefix == "" { + continue + } + // 优先前缀匹配;若 UA/Originator 被网关拼接为复合字符串时,退化为包含匹配。 + if strings.HasPrefix(value, normalizedPrefix) || strings.Contains(value, normalizedPrefix) { return true } } diff --git a/backend/internal/pkg/openai/request_test.go b/backend/internal/pkg/openai/request_test.go new file mode 100644 index 
00000000..508bf561 --- /dev/null +++ b/backend/internal/pkg/openai/request_test.go @@ -0,0 +1,87 @@ +package openai + +import "testing" + +func TestIsCodexCLIRequest(t *testing.T) { + tests := []struct { + name string + ua string + want bool + }{ + {name: "codex_cli_rs 前缀", ua: "codex_cli_rs/0.1.0", want: true}, + {name: "codex_vscode 前缀", ua: "codex_vscode/1.2.3", want: true}, + {name: "大小写混合", ua: "Codex_CLI_Rs/0.1.0", want: true}, + {name: "复合 UA 包含 codex", ua: "Mozilla/5.0 codex_cli_rs/0.1.0", want: true}, + {name: "空白包裹", ua: " codex_vscode/1.2.3 ", want: true}, + {name: "非 codex", ua: "curl/8.0.1", want: false}, + {name: "空字符串", ua: "", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsCodexCLIRequest(tt.ua) + if got != tt.want { + t.Fatalf("IsCodexCLIRequest(%q) = %v, want %v", tt.ua, got, tt.want) + } + }) + } +} + +func TestIsCodexOfficialClientRequest(t *testing.T) { + tests := []struct { + name string + ua string + want bool + }{ + {name: "codex_cli_rs 前缀", ua: "codex_cli_rs/0.98.0", want: true}, + {name: "codex_vscode 前缀", ua: "codex_vscode/1.0.0", want: true}, + {name: "codex_app 前缀", ua: "codex_app/0.1.0", want: true}, + {name: "codex_chatgpt_desktop 前缀", ua: "codex_chatgpt_desktop/1.0.0", want: true}, + {name: "codex_atlas 前缀", ua: "codex_atlas/1.0.0", want: true}, + {name: "codex_exec 前缀", ua: "codex_exec/0.1.0", want: true}, + {name: "codex_sdk_ts 前缀", ua: "codex_sdk_ts/0.1.0", want: true}, + {name: "Codex 桌面 UA", ua: "Codex Desktop/1.2.3", want: true}, + {name: "复合 UA 包含 codex_app", ua: "Mozilla/5.0 codex_app/0.1.0", want: true}, + {name: "大小写混合", ua: "Codex_VSCode/1.2.3", want: true}, + {name: "非 codex", ua: "curl/8.0.1", want: false}, + {name: "空字符串", ua: "", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsCodexOfficialClientRequest(tt.ua) + if got != tt.want { + t.Fatalf("IsCodexOfficialClientRequest(%q) = %v, want %v", tt.ua, got, tt.want) + } + 
}) + } +} + +func TestIsCodexOfficialClientOriginator(t *testing.T) { + tests := []struct { + name string + originator string + want bool + }{ + {name: "codex_cli_rs", originator: "codex_cli_rs", want: true}, + {name: "codex_vscode", originator: "codex_vscode", want: true}, + {name: "codex_app", originator: "codex_app", want: true}, + {name: "codex_chatgpt_desktop", originator: "codex_chatgpt_desktop", want: true}, + {name: "codex_atlas", originator: "codex_atlas", want: true}, + {name: "codex_exec", originator: "codex_exec", want: true}, + {name: "codex_sdk_ts", originator: "codex_sdk_ts", want: true}, + {name: "Codex 前缀", originator: "Codex Desktop", want: true}, + {name: "空白包裹", originator: " codex_vscode ", want: true}, + {name: "非 codex", originator: "my_client", want: false}, + {name: "空字符串", originator: "", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsCodexOfficialClientOriginator(tt.originator) + if got != tt.want { + t.Fatalf("IsCodexOfficialClientOriginator(%q) = %v, want %v", tt.originator, got, tt.want) + } + }) + } +} diff --git a/backend/internal/pkg/proxyurl/parse.go b/backend/internal/pkg/proxyurl/parse.go new file mode 100644 index 00000000..217556f2 --- /dev/null +++ b/backend/internal/pkg/proxyurl/parse.go @@ -0,0 +1,66 @@ +// Package proxyurl 提供代理 URL 的统一验证(fail-fast,无效代理不回退直连) +// +// 所有需要解析代理 URL 的地方必须通过此包的 Parse 函数。 +// 直接使用 url.Parse 处理代理 URL 是被禁止的。 +// 这确保了 fail-fast 行为:无效代理配置在创建时立即失败, +// 而不是在运行时静默回退到直连(产生 IP 关联风险)。 +package proxyurl + +import ( + "fmt" + "net/url" + "strings" +) + +// allowedSchemes 代理协议白名单 +var allowedSchemes = map[string]bool{ + "http": true, + "https": true, + "socks5": true, + "socks5h": true, +} + +// Parse 解析并验证代理 URL。 +// +// 语义: +// - 空字符串 → ("", nil, nil),表示直连 +// - 非空且有效 → (trimmed, *url.URL, nil) +// - 非空但无效 → ("", nil, error),fail-fast 不回退 +// +// 验证规则: +// - TrimSpace 后为空视为直连 +// - url.Parse 失败返回 error(不含原始 URL,防凭据泄露) +// - Host 为空返回 error(用 Redacted() 脱敏) 
+// - Scheme 必须为 http/https/socks5/socks5h +// - socks5:// 自动升级为 socks5h://(确保 DNS 由代理端解析,防止 DNS 泄漏) +func Parse(raw string) (trimmed string, parsed *url.URL, err error) { + trimmed = strings.TrimSpace(raw) + if trimmed == "" { + return "", nil, nil + } + + parsed, err = url.Parse(trimmed) + if err != nil { + // 不使用 %w 包装,避免 url.Parse 的底层错误消息泄漏原始 URL(可能含凭据) + return "", nil, fmt.Errorf("invalid proxy URL: %v", err) + } + + if parsed.Host == "" || parsed.Hostname() == "" { + return "", nil, fmt.Errorf("proxy URL missing host: %s", parsed.Redacted()) + } + + scheme := strings.ToLower(parsed.Scheme) + if !allowedSchemes[scheme] { + return "", nil, fmt.Errorf("unsupported proxy scheme %q (allowed: http, https, socks5, socks5h)", scheme) + } + + // 自动升级 socks5 → socks5h,确保 DNS 由代理端解析,防止 DNS 泄漏。 + // Go 的 golang.org/x/net/proxy 对 socks5:// 默认在客户端本地解析 DNS, + // 仅 socks5h:// 才将域名发送给代理端做远程 DNS 解析。 + if scheme == "socks5" { + parsed.Scheme = "socks5h" + trimmed = parsed.String() + } + + return trimmed, parsed, nil +} diff --git a/backend/internal/pkg/proxyurl/parse_test.go b/backend/internal/pkg/proxyurl/parse_test.go new file mode 100644 index 00000000..5fb57c16 --- /dev/null +++ b/backend/internal/pkg/proxyurl/parse_test.go @@ -0,0 +1,215 @@ +package proxyurl + +import ( + "strings" + "testing" +) + +func TestParse_空字符串直连(t *testing.T) { + trimmed, parsed, err := Parse("") + if err != nil { + t.Fatalf("空字符串应直连: %v", err) + } + if trimmed != "" { + t.Errorf("trimmed 应为空: got %q", trimmed) + } + if parsed != nil { + t.Errorf("parsed 应为 nil: got %v", parsed) + } +} + +func TestParse_空白字符串直连(t *testing.T) { + trimmed, parsed, err := Parse(" ") + if err != nil { + t.Fatalf("空白字符串应直连: %v", err) + } + if trimmed != "" { + t.Errorf("trimmed 应为空: got %q", trimmed) + } + if parsed != nil { + t.Errorf("parsed 应为 nil: got %v", parsed) + } +} + +func TestParse_有效HTTP代理(t *testing.T) { + trimmed, parsed, err := Parse("http://proxy.example.com:8080") + if err != nil { + t.Fatalf("有效 HTTP 
代理应成功: %v", err) + } + if trimmed != "http://proxy.example.com:8080" { + t.Errorf("trimmed 不匹配: got %q", trimmed) + } + if parsed == nil { + t.Fatal("parsed 不应为 nil") + } + if parsed.Host != "proxy.example.com:8080" { + t.Errorf("Host 不匹配: got %q", parsed.Host) + } +} + +func TestParse_有效HTTPS代理(t *testing.T) { + _, parsed, err := Parse("https://proxy.example.com:443") + if err != nil { + t.Fatalf("有效 HTTPS 代理应成功: %v", err) + } + if parsed.Scheme != "https" { + t.Errorf("Scheme 不匹配: got %q", parsed.Scheme) + } +} + +func TestParse_有效SOCKS5代理_自动升级为SOCKS5H(t *testing.T) { + trimmed, parsed, err := Parse("socks5://127.0.0.1:1080") + if err != nil { + t.Fatalf("有效 SOCKS5 代理应成功: %v", err) + } + // socks5 自动升级为 socks5h,确保 DNS 由代理端解析 + if trimmed != "socks5h://127.0.0.1:1080" { + t.Errorf("trimmed 应升级为 socks5h: got %q", trimmed) + } + if parsed.Scheme != "socks5h" { + t.Errorf("Scheme 应升级为 socks5h: got %q", parsed.Scheme) + } +} + +func TestParse_无效URL(t *testing.T) { + _, _, err := Parse("://invalid") + if err == nil { + t.Fatal("无效 URL 应返回错误") + } + if !strings.Contains(err.Error(), "invalid proxy URL") { + t.Errorf("错误信息应包含 'invalid proxy URL': got %s", err.Error()) + } +} + +func TestParse_缺少Host(t *testing.T) { + _, _, err := Parse("http://") + if err == nil { + t.Fatal("缺少 host 应返回错误") + } + if !strings.Contains(err.Error(), "missing host") { + t.Errorf("错误信息应包含 'missing host': got %s", err.Error()) + } +} + +func TestParse_不支持的Scheme(t *testing.T) { + _, _, err := Parse("ftp://proxy.example.com:21") + if err == nil { + t.Fatal("不支持的 scheme 应返回错误") + } + if !strings.Contains(err.Error(), "unsupported proxy scheme") { + t.Errorf("错误信息应包含 'unsupported proxy scheme': got %s", err.Error()) + } +} + +func TestParse_含密码URL脱敏(t *testing.T) { + // 场景 1: 带密码的 socks5 URL 应成功解析并升级为 socks5h + trimmed, parsed, err := Parse("socks5://user:secret_password@proxy.local:1080") + if err != nil { + t.Fatalf("含密码的有效 URL 应成功: %v", err) + } + if trimmed == "" || parsed == nil { + 
t.Fatal("应返回非空结果") + } + if parsed.Scheme != "socks5h" { + t.Errorf("Scheme 应升级为 socks5h: got %q", parsed.Scheme) + } + if !strings.HasPrefix(trimmed, "socks5h://") { + t.Errorf("trimmed 应以 socks5h:// 开头: got %q", trimmed) + } + if parsed.User == nil { + t.Error("升级后应保留 UserInfo") + } + + // 场景 2: 带密码但缺少 host(触发 Redacted 脱敏路径) + _, _, err = Parse("http://user:secret_password@:0/") + if err == nil { + t.Fatal("缺少 host 应返回错误") + } + if strings.Contains(err.Error(), "secret_password") { + t.Error("错误信息不应包含明文密码") + } + if !strings.Contains(err.Error(), "missing host") { + t.Errorf("错误信息应包含 'missing host': got %s", err.Error()) + } +} + +func TestParse_带空白的有效URL(t *testing.T) { + trimmed, parsed, err := Parse(" http://proxy.example.com:8080 ") + if err != nil { + t.Fatalf("带空白的有效 URL 应成功: %v", err) + } + if trimmed != "http://proxy.example.com:8080" { + t.Errorf("trimmed 应去除空白: got %q", trimmed) + } + if parsed == nil { + t.Fatal("parsed 不应为 nil") + } +} + +func TestParse_Scheme大小写不敏感(t *testing.T) { + // 大写 SOCKS5 应被接受并升级为 socks5h + trimmed, parsed, err := Parse("SOCKS5://proxy.example.com:1080") + if err != nil { + t.Fatalf("大写 SOCKS5 应被接受: %v", err) + } + if parsed.Scheme != "socks5h" { + t.Errorf("大写 SOCKS5 Scheme 应升级为 socks5h: got %q", parsed.Scheme) + } + if !strings.HasPrefix(trimmed, "socks5h://") { + t.Errorf("大写 SOCKS5 trimmed 应升级为 socks5h://: got %q", trimmed) + } + + // 大写 HTTP 应被接受(不变) + _, _, err = Parse("HTTP://proxy.example.com:8080") + if err != nil { + t.Fatalf("大写 HTTP 应被接受: %v", err) + } +} + +func TestParse_带认证的有效代理(t *testing.T) { + trimmed, parsed, err := Parse("http://user:pass@proxy.example.com:8080") + if err != nil { + t.Fatalf("带认证的代理 URL 应成功: %v", err) + } + if parsed.User == nil { + t.Error("应保留 UserInfo") + } + if trimmed != "http://user:pass@proxy.example.com:8080" { + t.Errorf("trimmed 不匹配: got %q", trimmed) + } +} + +func TestParse_IPv6地址(t *testing.T) { + trimmed, parsed, err := Parse("http://[::1]:8080") + if err != nil { + 
t.Fatalf("IPv6 代理 URL 应成功: %v", err) + } + if parsed.Hostname() != "::1" { + t.Errorf("Hostname 不匹配: got %q", parsed.Hostname()) + } + if trimmed != "http://[::1]:8080" { + t.Errorf("trimmed 不匹配: got %q", trimmed) + } +} + +func TestParse_SOCKS5H保持不变(t *testing.T) { + trimmed, parsed, err := Parse("socks5h://proxy.local:1080") + if err != nil { + t.Fatalf("有效 SOCKS5H 代理应成功: %v", err) + } + // socks5h 不需要升级,应保持原样 + if trimmed != "socks5h://proxy.local:1080" { + t.Errorf("trimmed 不应变化: got %q", trimmed) + } + if parsed.Scheme != "socks5h" { + t.Errorf("Scheme 应保持 socks5h: got %q", parsed.Scheme) + } +} + +func TestParse_无Scheme裸地址(t *testing.T) { + // 无 scheme 的裸地址,Go url.Parse 将其视为 path,Host 为空 + _, _, err := Parse("proxy.example.com:8080") + if err == nil { + t.Fatal("无 scheme 的裸地址应返回错误") + } +} diff --git a/backend/internal/pkg/proxyutil/dialer.go b/backend/internal/pkg/proxyutil/dialer.go index 91b224a2..e437cae3 100644 --- a/backend/internal/pkg/proxyutil/dialer.go +++ b/backend/internal/pkg/proxyutil/dialer.go @@ -2,7 +2,11 @@ // // 支持的代理协议: // - HTTP/HTTPS: 通过 Transport.Proxy 设置 -// - SOCKS5/SOCKS5H: 通过 Transport.DialContext 设置(服务端解析 DNS) +// - SOCKS5: 通过 Transport.DialContext 设置(客户端本地解析 DNS) +// - SOCKS5H: 通过 Transport.DialContext 设置(代理端远程解析 DNS,推荐) +// +// 注意:proxyurl.Parse() 会自动将 socks5:// 升级为 socks5h://, +// 确保 DNS 也由代理端解析,防止 DNS 泄漏。 package proxyutil import ( @@ -20,7 +24,8 @@ import ( // // 支持的协议: // - http/https: 设置 transport.Proxy -// - socks5/socks5h: 设置 transport.DialContext(由代理服务端解析 DNS) +// - socks5: 设置 transport.DialContext(客户端本地解析 DNS) +// - socks5h: 设置 transport.DialContext(代理端远程解析 DNS,推荐) // // 参数: // - transport: 需要配置的 http.Transport diff --git a/backend/internal/pkg/response/response.go b/backend/internal/pkg/response/response.go index c5b41d6e..0519c2cc 100644 --- a/backend/internal/pkg/response/response.go +++ b/backend/internal/pkg/response/response.go @@ -7,6 +7,7 @@ import ( "net/http" infraerrors 
"github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/util/logredact" "github.com/gin-gonic/gin" ) @@ -78,7 +79,7 @@ func ErrorFrom(c *gin.Context, err error) bool { // Log internal errors with full details for debugging if statusCode >= 500 && c.Request != nil { - log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, err.Error()) + log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, logredact.RedactText(err.Error())) } ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata) diff --git a/backend/internal/pkg/response/response_test.go b/backend/internal/pkg/response/response_test.go index ef31ca3c..0debce5f 100644 --- a/backend/internal/pkg/response/response_test.go +++ b/backend/internal/pkg/response/response_test.go @@ -14,6 +14,44 @@ import ( "github.com/stretchr/testify/require" ) +// ---------- 辅助函数 ---------- + +// parseResponseBody 从 httptest.ResponseRecorder 中解析 JSON 响应体 +func parseResponseBody(t *testing.T, w *httptest.ResponseRecorder) Response { + t.Helper() + var got Response + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got)) + return got +} + +// parsePaginatedBody 从响应体中解析分页数据(Data 字段是 PaginatedData) +func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, PaginatedData) { + t.Helper() + // 先用 raw json 解析,因为 Data 是 any 类型 + var raw struct { + Code int `json:"code"` + Message string `json:"message"` + Reason string `json:"reason,omitempty"` + Data json.RawMessage `json:"data,omitempty"` + } + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw)) + + var pd PaginatedData + require.NoError(t, json.Unmarshal(raw.Data, &pd)) + + return Response{Code: raw.Code, Message: raw.Message, Reason: raw.Reason}, pd +} + +// newContextWithQuery 创建一个带有 URL query 参数的 gin.Context 用于测试 ParsePagination +func newContextWithQuery(query string) (*httptest.ResponseRecorder, *gin.Context) { + w := httptest.NewRecorder() + c, _ := 
gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/?"+query, nil) + return w, c +} + +// ---------- 现有测试 ---------- + func TestErrorWithDetails(t *testing.T) { gin.SetMode(gin.TestMode) @@ -169,3 +207,582 @@ func TestErrorFrom(t *testing.T) { }) } } + +// ---------- 新增测试 ---------- + +func TestSuccess(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + data any + wantCode int + wantBody Response + }{ + { + name: "返回字符串数据", + data: "hello", + wantCode: http.StatusOK, + wantBody: Response{Code: 0, Message: "success", Data: "hello"}, + }, + { + name: "返回nil数据", + data: nil, + wantCode: http.StatusOK, + wantBody: Response{Code: 0, Message: "success"}, + }, + { + name: "返回map数据", + data: map[string]string{"key": "value"}, + wantCode: http.StatusOK, + wantBody: Response{Code: 0, Message: "success"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + Success(c, tt.data) + + require.Equal(t, tt.wantCode, w.Code) + + // 只验证 code 和 message,data 字段类型在 JSON 反序列化时会变成 map/slice + got := parseResponseBody(t, w) + require.Equal(t, 0, got.Code) + require.Equal(t, "success", got.Message) + + if tt.data == nil { + require.Nil(t, got.Data) + } else { + require.NotNil(t, got.Data) + } + }) + } +} + +func TestCreated(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + data any + wantCode int + }{ + { + name: "创建成功_返回数据", + data: map[string]int{"id": 42}, + wantCode: http.StatusCreated, + }, + { + name: "创建成功_nil数据", + data: nil, + wantCode: http.StatusCreated, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + Created(c, tt.data) + + require.Equal(t, tt.wantCode, w.Code) + + got := parseResponseBody(t, w) + require.Equal(t, 0, got.Code) + require.Equal(t, "success", got.Message) + }) + } +} + +func TestError(t 
*testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + statusCode int + message string + }{ + { + name: "400错误", + statusCode: http.StatusBadRequest, + message: "bad request", + }, + { + name: "500错误", + statusCode: http.StatusInternalServerError, + message: "internal error", + }, + { + name: "自定义状态码", + statusCode: 418, + message: "I'm a teapot", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + Error(c, tt.statusCode, tt.message) + + require.Equal(t, tt.statusCode, w.Code) + + got := parseResponseBody(t, w) + require.Equal(t, tt.statusCode, got.Code) + require.Equal(t, tt.message, got.Message) + require.Empty(t, got.Reason) + require.Nil(t, got.Metadata) + require.Nil(t, got.Data) + }) + } +} + +func TestBadRequest(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + BadRequest(c, "参数无效") + + require.Equal(t, http.StatusBadRequest, w.Code) + got := parseResponseBody(t, w) + require.Equal(t, http.StatusBadRequest, got.Code) + require.Equal(t, "参数无效", got.Message) +} + +func TestUnauthorized(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + Unauthorized(c, "未登录") + + require.Equal(t, http.StatusUnauthorized, w.Code) + got := parseResponseBody(t, w) + require.Equal(t, http.StatusUnauthorized, got.Code) + require.Equal(t, "未登录", got.Message) +} + +func TestForbidden(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + Forbidden(c, "无权限") + + require.Equal(t, http.StatusForbidden, w.Code) + got := parseResponseBody(t, w) + require.Equal(t, http.StatusForbidden, got.Code) + require.Equal(t, "无权限", got.Message) +} + +func TestNotFound(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + NotFound(c, 
"资源不存在") + + require.Equal(t, http.StatusNotFound, w.Code) + got := parseResponseBody(t, w) + require.Equal(t, http.StatusNotFound, got.Code) + require.Equal(t, "资源不存在", got.Message) +} + +func TestInternalError(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + InternalError(c, "服务器内部错误") + + require.Equal(t, http.StatusInternalServerError, w.Code) + got := parseResponseBody(t, w) + require.Equal(t, http.StatusInternalServerError, got.Code) + require.Equal(t, "服务器内部错误", got.Message) +} + +func TestPaginated(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + items any + total int64 + page int + pageSize int + wantPages int + wantTotal int64 + wantPage int + wantPageSize int + }{ + { + name: "标准分页_多页", + items: []string{"a", "b"}, + total: 25, + page: 1, + pageSize: 10, + wantPages: 3, + wantTotal: 25, + wantPage: 1, + wantPageSize: 10, + }, + { + name: "总数刚好整除", + items: []string{"a"}, + total: 20, + page: 2, + pageSize: 10, + wantPages: 2, + wantTotal: 20, + wantPage: 2, + wantPageSize: 10, + }, + { + name: "总数为0_pages至少为1", + items: []string{}, + total: 0, + page: 1, + pageSize: 10, + wantPages: 1, + wantTotal: 0, + wantPage: 1, + wantPageSize: 10, + }, + { + name: "单页数据", + items: []int{1, 2, 3}, + total: 3, + page: 1, + pageSize: 20, + wantPages: 1, + wantTotal: 3, + wantPage: 1, + wantPageSize: 20, + }, + { + name: "总数为1", + items: []string{"only"}, + total: 1, + page: 1, + pageSize: 10, + wantPages: 1, + wantTotal: 1, + wantPage: 1, + wantPageSize: 10, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + Paginated(c, tt.items, tt.total, tt.page, tt.pageSize) + + require.Equal(t, http.StatusOK, w.Code) + + resp, pd := parsePaginatedBody(t, w) + require.Equal(t, 0, resp.Code) + require.Equal(t, "success", resp.Message) + require.Equal(t, tt.wantTotal, pd.Total) + 
require.Equal(t, tt.wantPage, pd.Page) + require.Equal(t, tt.wantPageSize, pd.PageSize) + require.Equal(t, tt.wantPages, pd.Pages) + }) + } +} + +func TestPaginatedWithResult(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + items any + pagination *PaginationResult + wantTotal int64 + wantPage int + wantPageSize int + wantPages int + }{ + { + name: "正常分页结果", + items: []string{"a", "b"}, + pagination: &PaginationResult{ + Total: 50, + Page: 3, + PageSize: 10, + Pages: 5, + }, + wantTotal: 50, + wantPage: 3, + wantPageSize: 10, + wantPages: 5, + }, + { + name: "pagination为nil_使用默认值", + items: []string{}, + pagination: nil, + wantTotal: 0, + wantPage: 1, + wantPageSize: 20, + wantPages: 1, + }, + { + name: "单页结果", + items: []int{1}, + pagination: &PaginationResult{ + Total: 1, + Page: 1, + PageSize: 20, + Pages: 1, + }, + wantTotal: 1, + wantPage: 1, + wantPageSize: 20, + wantPages: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + PaginatedWithResult(c, tt.items, tt.pagination) + + require.Equal(t, http.StatusOK, w.Code) + + resp, pd := parsePaginatedBody(t, w) + require.Equal(t, 0, resp.Code) + require.Equal(t, "success", resp.Message) + require.Equal(t, tt.wantTotal, pd.Total) + require.Equal(t, tt.wantPage, pd.Page) + require.Equal(t, tt.wantPageSize, pd.PageSize) + require.Equal(t, tt.wantPages, pd.Pages) + }) + } +} + +func TestParsePagination(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + query string + wantPage int + wantPageSize int + }{ + { + name: "无参数_使用默认值", + query: "", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "仅指定page", + query: "page=3", + wantPage: 3, + wantPageSize: 20, + }, + { + name: "仅指定page_size", + query: "page_size=50", + wantPage: 1, + wantPageSize: 50, + }, + { + name: "同时指定page和page_size", + query: "page=2&page_size=30", + wantPage: 2, + wantPageSize: 30, + }, + 
{ + name: "使用limit代替page_size", + query: "limit=15", + wantPage: 1, + wantPageSize: 15, + }, + { + name: "page_size优先于limit", + query: "page_size=25&limit=50", + wantPage: 1, + wantPageSize: 25, + }, + { + name: "page为0_使用默认值", + query: "page=0", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "page_size超过1000_使用默认值", + query: "page_size=1001", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "page_size恰好1000_有效", + query: "page_size=1000", + wantPage: 1, + wantPageSize: 1000, + }, + { + name: "page为非数字_使用默认值", + query: "page=abc", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "page_size为非数字_使用默认值", + query: "page_size=xyz", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "limit为非数字_使用默认值", + query: "limit=abc", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "page_size为0_使用默认值", + query: "page_size=0", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "limit为0_使用默认值", + query: "limit=0", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "大页码", + query: "page=999&page_size=100", + wantPage: 999, + wantPageSize: 100, + }, + { + name: "page_size为1_最小有效值", + query: "page_size=1", + wantPage: 1, + wantPageSize: 1, + }, + { + name: "混合数字和字母的page", + query: "page=12a", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "limit超过1000_使用默认值", + query: "limit=2000", + wantPage: 1, + wantPageSize: 20, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, c := newContextWithQuery(tt.query) + + page, pageSize := ParsePagination(c) + + require.Equal(t, tt.wantPage, page, "page 不符合预期") + require.Equal(t, tt.wantPageSize, pageSize, "pageSize 不符合预期") + }) + } +} + +func Test_parseInt(t *testing.T) { + tests := []struct { + name string + input string + wantVal int + wantErr bool + }{ + { + name: "正常数字", + input: "123", + wantVal: 123, + wantErr: false, + }, + { + name: "零", + input: "0", + wantVal: 0, + wantErr: false, + }, + { + name: "单个数字", + input: "5", + wantVal: 5, + wantErr: false, + }, + { + name: "大数字", + input: 
"99999", + wantVal: 99999, + wantErr: false, + }, + { + name: "包含字母_返回0", + input: "abc", + wantVal: 0, + wantErr: false, + }, + { + name: "数字开头接字母_返回0", + input: "12a", + wantVal: 0, + wantErr: false, + }, + { + name: "包含负号_返回0", + input: "-1", + wantVal: 0, + wantErr: false, + }, + { + name: "包含小数点_返回0", + input: "1.5", + wantVal: 0, + wantErr: false, + }, + { + name: "包含空格_返回0", + input: "1 2", + wantVal: 0, + wantErr: false, + }, + { + name: "空字符串", + input: "", + wantVal: 0, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, err := parseInt(tt.input) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.Equal(t, tt.wantVal, val) + }) + } +} diff --git a/backend/internal/pkg/tlsfingerprint/dialer.go b/backend/internal/pkg/tlsfingerprint/dialer.go index 42510986..4f25a34a 100644 --- a/backend/internal/pkg/tlsfingerprint/dialer.go +++ b/backend/internal/pkg/tlsfingerprint/dialer.go @@ -268,8 +268,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st "cipher_suites", len(spec.CipherSuites), "extensions", len(spec.Extensions), "compression_methods", spec.CompressionMethods, - "tls_vers_max", fmt.Sprintf("0x%04x", spec.TLSVersMax), - "tls_vers_min", fmt.Sprintf("0x%04x", spec.TLSVersMin)) + "tls_vers_max", spec.TLSVersMax, + "tls_vers_min", spec.TLSVersMin) if d.profile != nil { slog.Debug("tls_fingerprint_socks5_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE) @@ -286,7 +286,7 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st return nil, fmt.Errorf("apply TLS preset: %w", err) } - if err := tlsConn.Handshake(); err != nil { + if err := tlsConn.HandshakeContext(ctx); err != nil { slog.Debug("tls_fingerprint_socks5_handshake_failed", "error", err) _ = conn.Close() return nil, fmt.Errorf("TLS handshake failed: %w", err) @@ -294,8 +294,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx 
context.Context, network, addr st state := tlsConn.ConnectionState() slog.Debug("tls_fingerprint_socks5_handshake_success", - "version", fmt.Sprintf("0x%04x", state.Version), - "cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite), + "version", state.Version, + "cipher_suite", state.CipherSuite, "alpn", state.NegotiatedProtocol) return tlsConn, nil @@ -404,8 +404,8 @@ func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr stri state := tlsConn.ConnectionState() slog.Debug("tls_fingerprint_http_proxy_handshake_success", - "version", fmt.Sprintf("0x%04x", state.Version), - "cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite), + "version", state.Version, + "cipher_suite", state.CipherSuite, "alpn", state.NegotiatedProtocol) return tlsConn, nil @@ -470,8 +470,8 @@ func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net. // Log successful handshake details state := tlsConn.ConnectionState() slog.Debug("tls_fingerprint_handshake_success", - "version", fmt.Sprintf("0x%04x", state.Version), - "cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite), + "version", state.Version, + "cipher_suite", state.CipherSuite, "alpn", state.NegotiatedProtocol) return tlsConn, nil diff --git a/backend/internal/pkg/tlsfingerprint/dialer_integration_test.go b/backend/internal/pkg/tlsfingerprint/dialer_integration_test.go index eea74fcc..3f668fbe 100644 --- a/backend/internal/pkg/tlsfingerprint/dialer_integration_test.go +++ b/backend/internal/pkg/tlsfingerprint/dialer_integration_test.go @@ -30,7 +30,8 @@ func skipIfExternalServiceUnavailable(t *testing.T, err error) { strings.Contains(errStr, "connection refused") || strings.Contains(errStr, "no such host") || strings.Contains(errStr, "network is unreachable") || - strings.Contains(errStr, "timeout") { + strings.Contains(errStr, "timeout") || + strings.Contains(errStr, "deadline exceeded") { t.Skipf("skipping test: external service unavailable: %v", err) } t.Fatalf("failed to get 
fingerprint: %v", err) diff --git a/backend/internal/pkg/tlsfingerprint/dialer_test.go b/backend/internal/pkg/tlsfingerprint/dialer_test.go index dff7570f..6d3db174 100644 --- a/backend/internal/pkg/tlsfingerprint/dialer_test.go +++ b/backend/internal/pkg/tlsfingerprint/dialer_test.go @@ -1,3 +1,5 @@ +//go:build unit + // Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients. // // Unit tests for TLS fingerprint dialer. @@ -9,26 +11,161 @@ package tlsfingerprint import ( + "context" + "encoding/json" + "io" + "net/http" "net/url" + "os" + "strings" "testing" + "time" ) -// FingerprintResponse represents the response from tls.peet.ws/api/all. -type FingerprintResponse struct { - IP string `json:"ip"` - TLS TLSInfo `json:"tls"` - HTTP2 any `json:"http2"` +// TestDialerBasicConnection tests that the dialer can establish TLS connections. +func TestDialerBasicConnection(t *testing.T) { + skipNetworkTest(t) + + // Create a dialer with default profile + profile := &Profile{ + Name: "Test Profile", + EnableGREASE: false, + } + dialer := NewDialer(profile, nil) + + // Create HTTP client with custom TLS dialer + client := &http.Client{ + Transport: &http.Transport{ + DialTLSContext: dialer.DialTLSContext, + }, + Timeout: 30 * time.Second, + } + + // Make a request to a known HTTPS endpoint + resp, err := client.Get("https://www.google.com") + if err != nil { + t.Fatalf("failed to connect: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } } -// TLSInfo contains TLS fingerprint details. -type TLSInfo struct { - JA3 string `json:"ja3"` - JA3Hash string `json:"ja3_hash"` - JA4 string `json:"ja4"` - PeetPrint string `json:"peetprint"` - PeetPrintHash string `json:"peetprint_hash"` - ClientRandom string `json:"client_random"` - SessionID string `json:"session_id"` +// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value. 
+// This test uses tls.peet.ws to verify the fingerprint. +// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x) +// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP) +func TestJA3Fingerprint(t *testing.T) { + skipNetworkTest(t) + + profile := &Profile{ + Name: "Claude CLI Test", + EnableGREASE: false, + } + dialer := NewDialer(profile, nil) + + client := &http.Client{ + Transport: &http.Transport{ + DialTLSContext: dialer.DialTLSContext, + }, + Timeout: 30 * time.Second, + } + + // Use tls.peet.ws fingerprint detection API + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("failed to get fingerprint: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response: %v", err) + } + + var fpResp FingerprintResponse + if err := json.Unmarshal(body, &fpResp); err != nil { + t.Logf("Response body: %s", string(body)) + t.Fatalf("failed to parse fingerprint response: %v", err) + } + + // Log all fingerprint information + t.Logf("JA3: %s", fpResp.TLS.JA3) + t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash) + t.Logf("JA4: %s", fpResp.TLS.JA4) + t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint) + t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash) + + // Verify JA3 hash matches expected value + expectedJA3Hash := "1a28e69016765d92e3b381168d68922c" + if fpResp.TLS.JA3Hash == expectedJA3Hash { + t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash) + } else { + t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash) + } + + // Verify JA4 fingerprint + // JA4 format: 
t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash] + // Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP) + // The suffix _a33745022dd6_1f22a2ca17c4 should match + expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4" + if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) { + t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix) + } else { + t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix) + } + + // Verify JA4 prefix (t13d5911h1 or t13i5911h1) + // d = domain (SNI present), i = IP (no SNI) + // Since we connect to tls.peet.ws (domain), we expect 'd' + expectedJA4Prefix := "t13d5911h1" + if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) { + t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix) + } else { + // Also accept 'i' variant for IP connections + altPrefix := "t13i5911h1" + if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) { + t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix) + } else { + t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix) + } + } + + // Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning) + if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") { + t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites") + } else { + t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites") + } + + // Verify extension list (should be 11 extensions including SNI) + // Expected: 0-11-10-35-16-22-23-13-43-45-51 + expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51" + if strings.Contains(fpResp.TLS.JA3, expectedExtensions) { + t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions) + } else { + t.Logf("Warning: JA3 extension list may differ") + } +} + +func skipNetworkTest(t *testing.T) { + if testing.Short() { + t.Skip("跳过网络测试(short 模式)") + } + if os.Getenv("TLSFINGERPRINT_NETWORK_TESTS") != 
"1" { + t.Skip("跳过网络测试(需要设置 TLSFINGERPRINT_NETWORK_TESTS=1)") + } } // TestDialerWithProfile tests that different profiles produce different fingerprints. @@ -158,3 +295,137 @@ func mustParseURL(rawURL string) *url.URL { } return u } + +// TestProfileExpectation defines expected fingerprint values for a profile. +type TestProfileExpectation struct { + Profile *Profile + ExpectedJA3 string // Expected JA3 hash (empty = don't check) + ExpectedJA4 string // Expected full JA4 (empty = don't check) + JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check) +} + +// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws. +// Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/... +func TestAllProfiles(t *testing.T) { + skipNetworkTest(t) + + // Define all profiles to test with their expected fingerprints + // These profiles are from config.yaml gateway.tls_fingerprint.profiles + profiles := []TestProfileExpectation{ + { + // Linux x64 Node.js v22.17.1 + // Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c + // Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 + Profile: &Profile{ + Name: "linux_x64_node_v22171", + EnableGREASE: false, + CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255}, + Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260}, + PointFormats: []uint8{0, 1, 2}, + }, + JA4CipherHash: "a33745022dd6", // stable part + }, + { + // MacOS arm64 Node.js v22.18.0 + // Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea + // Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406 + Profile: &Profile{ + Name: "macos_arm64_node_v22180", + EnableGREASE: 
false, + CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255}, + Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260}, + PointFormats: []uint8{0, 1, 2}, + }, + JA4CipherHash: "a33745022dd6", // stable part (same cipher suites) + }, + } + + for _, tc := range profiles { + tc := tc // capture range variable + t.Run(tc.Profile.Name, func(t *testing.T) { + fp := fetchFingerprint(t, tc.Profile) + if fp == nil { + return // fetchFingerprint already called t.Fatal + } + + t.Logf("Profile: %s", tc.Profile.Name) + t.Logf(" JA3: %s", fp.JA3) + t.Logf(" JA3 Hash: %s", fp.JA3Hash) + t.Logf(" JA4: %s", fp.JA4) + t.Logf(" PeetPrint: %s", fp.PeetPrint) + t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash) + + // Verify expectations + if tc.ExpectedJA3 != "" { + if fp.JA3Hash == tc.ExpectedJA3 { + t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3) + } else { + t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3) + } + } + + if tc.ExpectedJA4 != "" { + if fp.JA4 == tc.ExpectedJA4 { + t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4) + } else { + t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4) + } + } + + // Check JA4 cipher hash (stable middle part) + // JA4 format: prefix_cipherHash_extHash + if tc.JA4CipherHash != "" { + if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") { + t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash) + } else { + t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash) + } + } + }) + } +} + +// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info. 
+func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo { + t.Helper() + + dialer := NewDialer(profile, nil) + client := &http.Client{ + Transport: &http.Transport{ + DialTLSContext: dialer.DialTLSContext, + }, + Timeout: 30 * time.Second, + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + return nil + } + req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("failed to get fingerprint: %v", err) + return nil + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response: %v", err) + return nil + } + + var fpResp FingerprintResponse + if err := json.Unmarshal(body, &fpResp); err != nil { + t.Logf("Response body: %s", string(body)) + t.Fatalf("failed to parse fingerprint response: %v", err) + return nil + } + + return &fpResp.TLS +} diff --git a/backend/internal/pkg/tlsfingerprint/test_types_test.go b/backend/internal/pkg/tlsfingerprint/test_types_test.go new file mode 100644 index 00000000..2bbf2d22 --- /dev/null +++ b/backend/internal/pkg/tlsfingerprint/test_types_test.go @@ -0,0 +1,20 @@ +package tlsfingerprint + +// FingerprintResponse represents the response from tls.peet.ws/api/all. +// 共享测试类型,供 unit 和 integration 测试文件使用。 +type FingerprintResponse struct { + IP string `json:"ip"` + TLS TLSInfo `json:"tls"` + HTTP2 any `json:"http2"` +} + +// TLSInfo contains TLS fingerprint details. 
+type TLSInfo struct { + JA3 string `json:"ja3"` + JA3Hash string `json:"ja3_hash"` + JA4 string `json:"ja4"` + PeetPrint string `json:"peetprint"` + PeetPrintHash string `json:"peetprint_hash"` + ClientRandom string `json:"client_random"` + SessionID string `json:"session_id"` +} diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index 2f6c7fe0..746188ea 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -78,6 +78,16 @@ type ModelStat struct { ActualCost float64 `json:"actual_cost"` // 实际扣除 } +// GroupStat represents usage statistics for a single group +type GroupStat struct { + GroupID int64 `json:"group_id"` + GroupName string `json:"group_name"` + Requests int64 `json:"requests"` + TotalTokens int64 `json:"total_tokens"` + Cost float64 `json:"cost"` // 标准计费 + ActualCost float64 `json:"actual_cost"` // 实际扣除 +} + // UserUsageTrendPoint represents user usage trend data point type UserUsageTrendPoint struct { Date string `json:"date"` @@ -139,10 +149,13 @@ type UsageLogFilters struct { AccountID int64 GroupID int64 Model string + RequestType *int16 Stream *bool BillingType *int8 StartTime *time.Time EndTime *time.Time + // ExactTotal requests exact COUNT(*) for pagination. Default false for fast large-table paging. 
+ ExactTotal bool } // UsageStats represents usage statistics diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 134d92dd..6f0c5424 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -15,7 +15,6 @@ import ( "database/sql" "encoding/json" "errors" - "log" "strconv" "time" @@ -25,6 +24,7 @@ import ( dbgroup "github.com/Wei-Shaw/sub2api/ent/group" dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate" dbproxy "github.com/Wei-Shaw/sub2api/ent/proxy" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/lib/pq" @@ -50,11 +50,6 @@ type accountRepository struct { schedulerCache service.SchedulerCache } -type tempUnschedSnapshot struct { - until *time.Time - reason string -} - // NewAccountRepository 创建账户仓储实例。 // 这是对外暴露的构造函数,返回接口类型以便于依赖注入。 func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository { @@ -127,7 +122,7 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account account.CreatedAt = created.CreatedAt account.UpdatedAt = created.UpdatedAt if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil { - log.Printf("[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err) } return nil } @@ -189,11 +184,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi accountIDs = append(accountIDs, acc.ID) } - tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs) - if err != nil { - return nil, err - } - groupsByAccount, groupIDsByAccount, 
accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs) if err != nil { return nil, err @@ -220,10 +210,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi if ags, ok := accountGroupsByAccount[entAcc.ID]; ok { out.AccountGroups = ags } - if snap, ok := tempUnschedMap[entAcc.ID]; ok { - out.TempUnschedulableUntil = snap.until - out.TempUnschedulableReason = snap.reason - } outByID[entAcc.ID] = out } @@ -388,7 +374,7 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account } account.UpdatedAt = updated.UpdatedAt if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil { - log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err) } if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable { r.syncSchedulerAccountSnapshot(ctx, account.ID) @@ -429,7 +415,7 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error { } } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, buildSchedulerGroupPayload(groupIDs)); err != nil { - log.Printf("[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err) } return nil } @@ -541,7 +527,7 @@ func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error }, } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, &id, nil, payload); err != nil { - log.Printf("[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", 
"[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err) } return nil } @@ -576,7 +562,7 @@ func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map } payload := map[string]any{"last_used": lastUsedPayload} if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, nil, nil, payload); err != nil { - log.Printf("[SchedulerOutbox] enqueue batch last used failed: err=%v", err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue batch last used failed: err=%v", err) } return nil } @@ -591,7 +577,7 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err) } r.syncSchedulerAccountSnapshot(ctx, id) return nil @@ -611,11 +597,48 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac } account, err := r.GetByID(ctx, accountID) if err != nil { - log.Printf("[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err) + logger.LegacyPrintf("repository.account", "[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err) return } if err := r.schedulerCache.SetAccount(ctx, account); err != nil { - log.Printf("[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err) + logger.LegacyPrintf("repository.account", "[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err) + } +} + +func (r *accountRepository) syncSchedulerAccountSnapshots(ctx context.Context, accountIDs []int64) { + if r == nil || r.schedulerCache == nil || len(accountIDs) == 0 { + return + } + + uniqueIDs := make([]int64, 0, len(accountIDs)) + seen := 
make(map[int64]struct{}, len(accountIDs)) + for _, id := range accountIDs { + if id <= 0 { + continue + } + if _, exists := seen[id]; exists { + continue + } + seen[id] = struct{}{} + uniqueIDs = append(uniqueIDs, id) + } + if len(uniqueIDs) == 0 { + return + } + + accounts, err := r.GetByIDs(ctx, uniqueIDs) + if err != nil { + logger.LegacyPrintf("repository.account", "[Scheduler] batch sync account snapshot read failed: count=%d err=%v", len(uniqueIDs), err) + return + } + + for _, account := range accounts { + if account == nil { + continue + } + if err := r.schedulerCache.SetAccount(ctx, account); err != nil { + logger.LegacyPrintf("repository.account", "[Scheduler] batch sync account snapshot write failed: id=%d err=%v", account.ID, err) + } } } @@ -649,7 +672,7 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i } payload := buildSchedulerGroupPayload([]int64{groupID}) if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil { - log.Printf("[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err) } return nil } @@ -666,7 +689,7 @@ func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, grou } payload := buildSchedulerGroupPayload([]int64{groupID}) if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil { - log.Printf("[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err) } return nil } @@ -739,7 +762,7 @@ func (r *accountRepository) BindGroups(ctx context.Context, 
accountID int64, gro } payload := buildSchedulerGroupPayload(mergeGroupIDs(existingGroupIDs, groupIDs)) if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil { - log.Printf("[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err) } return nil } @@ -824,6 +847,51 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat return r.accountsToService(ctx, accounts) } +func (r *accountRepository) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + now := time.Now() + accounts, err := r.client.Account.Query(). + Where( + dbaccount.PlatformEQ(platform), + dbaccount.StatusEQ(service.StatusActive), + dbaccount.SchedulableEQ(true), + dbaccount.Not(dbaccount.HasAccountGroups()), + tempUnschedulablePredicate(), + notExpiredPredicate(now), + dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), + dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), + ). + Order(dbent.Asc(dbaccount.FieldPriority)). + All(ctx) + if err != nil { + return nil, err + } + return r.accountsToService(ctx, accounts) +} + +func (r *accountRepository) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) { + if len(platforms) == 0 { + return nil, nil + } + now := time.Now() + accounts, err := r.client.Account.Query(). 
+ Where( + dbaccount.PlatformIn(platforms...), + dbaccount.StatusEQ(service.StatusActive), + dbaccount.SchedulableEQ(true), + dbaccount.Not(dbaccount.HasAccountGroups()), + tempUnschedulablePredicate(), + notExpiredPredicate(now), + dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)), + dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)), + ). + Order(dbent.Asc(dbaccount.FieldPriority)). + All(ctx) + if err != nil { + return nil, err + } + return r.accountsToService(ctx, accounts) +} + func (r *accountRepository) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) { if len(platforms) == 0 { return nil, nil @@ -847,7 +915,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err) } return nil } @@ -894,7 +962,7 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco return service.ErrAccountNotFound } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err) } return nil } @@ -908,7 +976,7 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, 
err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err) } return nil } @@ -927,7 +995,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err) } r.syncSchedulerAccountSnapshot(ctx, id) return nil @@ -946,7 +1014,7 @@ func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64 return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err) } return nil } @@ -962,7 +1030,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err) } return nil } @@ -986,7 +1054,7 @@ func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id return service.ErrAccountNotFound } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", 
"[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err) } return nil } @@ -1010,7 +1078,7 @@ func (r *accountRepository) ClearModelRateLimits(ctx context.Context, id int64) return service.ErrAccountNotFound } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err) } return nil } @@ -1032,7 +1100,7 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s // 触发调度器缓存更新(仅当窗口时间有变化时) if start != nil || end != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err) } } return nil @@ -1047,7 +1115,7 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err) } if !schedulable { r.syncSchedulerAccountSnapshot(ctx, id) @@ -1075,7 +1143,7 @@ func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now ti } if rows > 0 { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventFullRebuild, nil, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err) + logger.LegacyPrintf("repository.account", 
"[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err) } } return rows, nil @@ -1111,7 +1179,7 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m return service.ErrAccountNotFound } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err) } return nil } @@ -1205,7 +1273,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates if rows > 0 { payload := map[string]any{"account_ids": ids} if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil { - log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue bulk update failed: err=%v", err) } shouldSync := false if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) { @@ -1215,9 +1283,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates shouldSync = true } if shouldSync { - for _, id := range ids { - r.syncSchedulerAccountSnapshot(ctx, id) - } + r.syncSchedulerAccountSnapshots(ctx, ids) } } return rows, nil @@ -1309,10 +1375,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d if err != nil { return nil, err } - tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs) - if err != nil { - return nil, err - } groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs) if err != nil { return nil, err @@ -1338,10 +1400,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d if ags, ok := accountGroupsByAccount[acc.ID]; ok { 
out.AccountGroups = ags } - if snap, ok := tempUnschedMap[acc.ID]; ok { - out.TempUnschedulableUntil = snap.until - out.TempUnschedulableReason = snap.reason - } outAccounts = append(outAccounts, *out) } @@ -1366,48 +1424,6 @@ func notExpiredPredicate(now time.Time) dbpredicate.Account { ) } -func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) { - out := make(map[int64]tempUnschedSnapshot) - if len(accountIDs) == 0 { - return out, nil - } - - rows, err := r.sql.QueryContext(ctx, ` - SELECT id, temp_unschedulable_until, temp_unschedulable_reason - FROM accounts - WHERE id = ANY($1) - `, pq.Array(accountIDs)) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - - for rows.Next() { - var id int64 - var until sql.NullTime - var reason sql.NullString - if err := rows.Scan(&id, &until, &reason); err != nil { - return nil, err - } - var untilPtr *time.Time - if until.Valid { - tmp := until.Time - untilPtr = &tmp - } - if reason.Valid { - out[id] = tempUnschedSnapshot{until: untilPtr, reason: reason.String} - } else { - out[id] = tempUnschedSnapshot{until: untilPtr, reason: ""} - } - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return out, nil -} - func (r *accountRepository) loadProxies(ctx context.Context, proxyIDs []int64) (map[int64]*service.Proxy, error) { proxyMap := make(map[int64]*service.Proxy) if len(proxyIDs) == 0 { @@ -1518,31 +1534,33 @@ func accountEntityToService(m *dbent.Account) *service.Account { rateMultiplier := m.RateMultiplier return &service.Account{ - ID: m.ID, - Name: m.Name, - Notes: m.Notes, - Platform: m.Platform, - Type: m.Type, - Credentials: copyJSONMap(m.Credentials), - Extra: copyJSONMap(m.Extra), - ProxyID: m.ProxyID, - Concurrency: m.Concurrency, - Priority: m.Priority, - RateMultiplier: &rateMultiplier, - Status: m.Status, - ErrorMessage: derefString(m.ErrorMessage), - LastUsedAt: m.LastUsedAt, - ExpiresAt: 
m.ExpiresAt, - AutoPauseOnExpired: m.AutoPauseOnExpired, - CreatedAt: m.CreatedAt, - UpdatedAt: m.UpdatedAt, - Schedulable: m.Schedulable, - RateLimitedAt: m.RateLimitedAt, - RateLimitResetAt: m.RateLimitResetAt, - OverloadUntil: m.OverloadUntil, - SessionWindowStart: m.SessionWindowStart, - SessionWindowEnd: m.SessionWindowEnd, - SessionWindowStatus: derefString(m.SessionWindowStatus), + ID: m.ID, + Name: m.Name, + Notes: m.Notes, + Platform: m.Platform, + Type: m.Type, + Credentials: copyJSONMap(m.Credentials), + Extra: copyJSONMap(m.Extra), + ProxyID: m.ProxyID, + Concurrency: m.Concurrency, + Priority: m.Priority, + RateMultiplier: &rateMultiplier, + Status: m.Status, + ErrorMessage: derefString(m.ErrorMessage), + LastUsedAt: m.LastUsedAt, + ExpiresAt: m.ExpiresAt, + AutoPauseOnExpired: m.AutoPauseOnExpired, + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + Schedulable: m.Schedulable, + RateLimitedAt: m.RateLimitedAt, + RateLimitResetAt: m.RateLimitResetAt, + OverloadUntil: m.OverloadUntil, + TempUnschedulableUntil: m.TempUnschedulableUntil, + TempUnschedulableReason: derefString(m.TempUnschedulableReason), + SessionWindowStart: m.SessionWindowStart, + SessionWindowEnd: m.SessionWindowEnd, + SessionWindowStatus: derefString(m.SessionWindowStatus), } } @@ -1578,3 +1596,64 @@ func joinClauses(clauses []string, sep string) string { func itoa(v int) string { return strconv.Itoa(v) } + +// FindByExtraField 根据 extra 字段中的键值对查找账号。 +// 该方法限定 platform='sora',避免误查询其他平台的账号。 +// 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。 +// +// 应用场景:查找通过 linked_openai_account_id 关联的 Sora 账号。 +// +// FindByExtraField finds accounts by key-value pairs in the extra field. +// Limited to platform='sora' to avoid querying accounts from other platforms. +// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index). +// +// Use case: Finding Sora accounts linked via linked_openai_account_id. 
// FindByExtraField finds accounts by a key-value pair in the JSONB `extra`
// field, restricted to platform='sora' so accounts from other platforms are
// never matched. The lookup is built with ent's sqljson.ValueEQ path
// predicates (NOTE(review): the original comment claimed the JSONB `@>`
// operator; ValueEQ generates path-extraction comparisons instead — confirm
// which form the GIN index actually accelerates).
//
// Because callers may have stored the value as either a JSON string or a JSON
// number (use case: linked_openai_account_id), numeric-looking inputs match
// both representations via an OR of predicates.
func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
	accounts, err := r.client.Account.Query().
		Where(
			dbaccount.PlatformEQ("sora"), // restrict to the sora platform
			dbaccount.DeletedAtIsNil(),
			func(s *entsql.Selector) {
				path := sqljson.Path(key)
				switch v := value.(type) {
				case string:
					// Always match the string form; if it parses as an
					// integer, also match the numeric JSON form.
					preds := []*entsql.Predicate{sqljson.ValueEQ(dbaccount.FieldExtra, v, path)}
					if parsed, err := strconv.ParseInt(v, 10, 64); err == nil {
						preds = append(preds, sqljson.ValueEQ(dbaccount.FieldExtra, parsed, path))
					}
					if len(preds) == 1 {
						s.Where(preds[0])
					} else {
						s.Where(entsql.Or(preds...))
					}
				case int:
					// Match both numeric and string-encoded forms.
					s.Where(entsql.Or(
						sqljson.ValueEQ(dbaccount.FieldExtra, v, path),
						sqljson.ValueEQ(dbaccount.FieldExtra, strconv.Itoa(v), path),
					))
				case int64:
					s.Where(entsql.Or(
						sqljson.ValueEQ(dbaccount.FieldExtra, v, path),
						sqljson.ValueEQ(dbaccount.FieldExtra, strconv.FormatInt(v, 10), path),
					))
				case json.Number:
					// Prefer the integer interpretation; fall back to the raw
					// string when the number is not an integer.
					if parsed, err := v.Int64(); err == nil {
						s.Where(entsql.Or(
							sqljson.ValueEQ(dbaccount.FieldExtra, parsed, path),
							sqljson.ValueEQ(dbaccount.FieldExtra, v.String(), path),
						))
					} else {
						s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, v.String(), path))
					}
				default:
					// Fallback: compare the value as-is.
					s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, value, path))
				}
			},
		).
		All(ctx)
	if err != nil {
		return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
	}

	return r.accountsToService(ctx, accounts)
}
// NewAPIKeyRepository constructs the API-key repository backed by the given
// ent client plus a raw *sql.DB handle used for the repository's hand-written
// atomic UPDATE/SELECT statements.
func NewAPIKeyRepository(client *dbent.Client, sqlDB *sql.DB) service.APIKeyRepository {
	return newAPIKeyRepositoryWithSQL(client, sqlDB)
}

// newAPIKeyRepositoryWithSQL is the internal constructor. It accepts any
// sqlExecutor (e.g. a *sql.Tx) so integration tests can run the repository
// inside a transaction.
func newAPIKeyRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *apiKeyRepository {
	return &apiKeyRepository{client: client, sql: sqlq}
}
SetStatus(key.Status). SetNillableGroupID(key.GroupID). + SetNillableLastUsedAt(key.LastUsedAt). SetQuota(key.Quota). SetQuotaUsed(key.QuotaUsed). - SetNillableExpiresAt(key.ExpiresAt) + SetNillableExpiresAt(key.ExpiresAt). + SetRateLimit5h(key.RateLimit5h). + SetRateLimit1d(key.RateLimit1d). + SetRateLimit7d(key.RateLimit7d) if len(key.IPWhitelist) > 0 { builder.SetIPWhitelist(key.IPWhitelist) @@ -48,6 +58,7 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro created, err := builder.Save(ctx) if err == nil { key.ID = created.ID + key.LastUsedAt = created.LastUsedAt key.CreatedAt = created.CreatedAt key.UpdatedAt = created.UpdatedAt } @@ -116,6 +127,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se apikey.FieldQuota, apikey.FieldQuotaUsed, apikey.FieldExpiresAt, + apikey.FieldRateLimit5h, + apikey.FieldRateLimit1d, + apikey.FieldRateLimit7d, ). WithUser(func(q *dbent.UserQuery) { q.Select( @@ -140,6 +154,10 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, + group.FieldSoraImagePrice360, + group.FieldSoraImagePrice540, + group.FieldSoraVideoPricePerRequest, + group.FieldSoraVideoPricePerRequestHd, group.FieldClaudeCodeOnly, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, @@ -165,13 +183,20 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro // 则会更新已删除的记录。 // 这里选择 Update().Where(),确保只有未软删除记录能被更新。 // 同时显式设置 updated_at,避免二次查询带来的并发可见性问题。 + client := clientFromContext(ctx, r.client) now := time.Now() - builder := r.client.APIKey.Update(). + builder := client.APIKey.Update(). Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()). SetName(key.Name). SetStatus(key.Status). SetQuota(key.Quota). SetQuotaUsed(key.QuotaUsed). + SetRateLimit5h(key.RateLimit5h). + SetRateLimit1d(key.RateLimit1d). + SetRateLimit7d(key.RateLimit7d). 
+ SetUsage5h(key.Usage5h). + SetUsage1d(key.Usage1d). + SetUsage7d(key.Usage7d). SetUpdatedAt(now) if key.GroupID != nil { builder.SetGroupID(*key.GroupID) @@ -186,6 +211,23 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro builder.ClearExpiresAt() } + // Rate limit window start times + if key.Window5hStart != nil { + builder.SetWindow5hStart(*key.Window5hStart) + } else { + builder.ClearWindow5hStart() + } + if key.Window1dStart != nil { + builder.SetWindow1dStart(*key.Window1dStart) + } else { + builder.ClearWindow1dStart() + } + if key.Window7dStart != nil { + builder.SetWindow7dStart(*key.Window7dStart) + } else { + builder.ClearWindow7dStart() + } + // IP 限制字段 if len(key.IPWhitelist) > 0 { builder.SetIPWhitelist(key.IPWhitelist) @@ -239,9 +281,27 @@ func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error { return nil } -func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { +func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) { q := r.activeQuery().Where(apikey.UserIDEQ(userID)) + // Apply filters + if filters.Search != "" { + q = q.Where(apikey.Or( + apikey.NameContainsFold(filters.Search), + apikey.KeyContainsFold(filters.Search), + )) + } + if filters.Status != "" { + q = q.Where(apikey.StatusEQ(filters.Status)) + } + if filters.GroupID != nil { + if *filters.GroupID == 0 { + q = q.Where(apikey.GroupIDIsNil()) + } else { + q = q.Where(apikey.GroupIDEQ(*filters.GroupID)) + } + } + total, err := q.Count(ctx) if err != nil { return nil, nil, err @@ -375,36 +435,92 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) return keys, nil } -// IncrementQuotaUsed atomically increments the quota_used field and returns the new 
// IncrementQuotaUsed atomically increments the quota_used field via ent's
// AddQuotaUsed (a single UPDATE ... SET quota_used = quota_used + $1) and
// returns the new value. Soft-deleted keys are excluded; a missing or deleted
// key yields service.ErrAPIKeyNotFound.
func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
	updated, err := r.client.APIKey.UpdateOneID(id).
		Where(apikey.DeletedAtIsNil()).
		AddQuotaUsed(amount).
		Save(ctx)
	if err != nil {
		if dbent.IsNotFound(err) {
			return 0, service.ErrAPIKeyNotFound
		}
		return 0, err
	}
	return updated.QuotaUsed, nil
}

// UpdateLastUsed stamps last_used_at (and updated_at) on a non-deleted key.
// Returns service.ErrAPIKeyNotFound when no matching row was updated.
func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
	affected, err := r.client.APIKey.Update().
		Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
		SetLastUsedAt(usedAt).
		SetUpdatedAt(usedAt).
		Save(ctx)
	if err != nil {
		return err
	}
	if affected == 0 {
		return service.ErrAPIKeyNotFound
	}
	return nil
}

// IncrementRateLimitUsage atomically increments all rate limit usage counters
// and initializes window start times via COALESCE if not already set.
// A single UPDATE keeps the three windows consistent under concurrency.
// Note: a missing/deleted id is NOT reported as an error here (0 rows
// affected is silently accepted).
func (r *apiKeyRepository) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
	_, err := r.sql.ExecContext(ctx, `
		UPDATE api_keys SET
			usage_5h = usage_5h + $1,
			usage_1d = usage_1d + $1,
			usage_7d = usage_7d + $1,
			window_5h_start = COALESCE(window_5h_start, NOW()),
			window_1d_start = COALESCE(window_1d_start, NOW()),
			window_7d_start = COALESCE(window_7d_start, NOW()),
			updated_at = NOW()
		WHERE id = $2 AND deleted_at IS NULL`,
		cost, id)
	return err
}
+func (r *apiKeyRepository) ResetRateLimitWindows(ctx context.Context, id int64) error { + _, err := r.sql.ExecContext(ctx, ` + UPDATE api_keys SET + usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN 0 ELSE usage_5h END, + window_5h_start = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END, + usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN 0 ELSE usage_1d END, + window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN NOW() ELSE window_1d_start END, + usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN 0 ELSE usage_7d END, + window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN NOW() ELSE window_7d_start END, + updated_at = NOW() + WHERE id = $1 AND deleted_at IS NULL`, + id) + return err +} + +// GetRateLimitData returns the current rate limit usage and window start times for an API key. 
+func (r *apiKeyRepository) GetRateLimitData(ctx context.Context, id int64) (result *service.APIKeyRateLimitData, err error) { + rows, err := r.sql.QueryContext(ctx, ` + SELECT usage_5h, usage_1d, usage_7d, window_5h_start, window_1d_start, window_7d_start + FROM api_keys + WHERE id = $1 AND deleted_at IS NULL`, + id) + if err != nil { + return nil, err + } + defer func() { + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + } + }() + if !rows.Next() { + return nil, service.ErrAPIKeyNotFound + } + data := &service.APIKeyRateLimitData{} + if err := rows.Scan(&data.Usage5h, &data.Usage1d, &data.Usage7d, &data.Window5hStart, &data.Window1dStart, &data.Window7dStart); err != nil { + return nil, err + } + return data, rows.Err() } func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { @@ -412,19 +528,29 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { return nil } out := &service.APIKey{ - ID: m.ID, - UserID: m.UserID, - Key: m.Key, - Name: m.Name, - Status: m.Status, - IPWhitelist: m.IPWhitelist, - IPBlacklist: m.IPBlacklist, - CreatedAt: m.CreatedAt, - UpdatedAt: m.UpdatedAt, - GroupID: m.GroupID, - Quota: m.Quota, - QuotaUsed: m.QuotaUsed, - ExpiresAt: m.ExpiresAt, + ID: m.ID, + UserID: m.UserID, + Key: m.Key, + Name: m.Name, + Status: m.Status, + IPWhitelist: m.IPWhitelist, + IPBlacklist: m.IPBlacklist, + LastUsedAt: m.LastUsedAt, + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + GroupID: m.GroupID, + Quota: m.Quota, + QuotaUsed: m.QuotaUsed, + ExpiresAt: m.ExpiresAt, + RateLimit5h: m.RateLimit5h, + RateLimit1d: m.RateLimit1d, + RateLimit7d: m.RateLimit7d, + Usage5h: m.Usage5h, + Usage1d: m.Usage1d, + Usage7d: m.Usage7d, + Window5hStart: m.Window5hStart, + Window1dStart: m.Window1dStart, + Window7dStart: m.Window7dStart, } if m.Edges.User != nil { out.User = userEntityToService(m.Edges.User) @@ -440,20 +566,22 @@ func userEntityToService(u *dbent.User) *service.User { return nil } return &service.User{ - ID: 
u.ID, - Email: u.Email, - Username: u.Username, - Notes: u.Notes, - PasswordHash: u.PasswordHash, - Role: u.Role, - Balance: u.Balance, - Concurrency: u.Concurrency, - Status: u.Status, - TotpSecretEncrypted: u.TotpSecretEncrypted, - TotpEnabled: u.TotpEnabled, - TotpEnabledAt: u.TotpEnabledAt, - CreatedAt: u.CreatedAt, - UpdatedAt: u.UpdatedAt, + ID: u.ID, + Email: u.Email, + Username: u.Username, + Notes: u.Notes, + PasswordHash: u.PasswordHash, + Role: u.Role, + Balance: u.Balance, + Concurrency: u.Concurrency, + Status: u.Status, + SoraStorageQuotaBytes: u.SoraStorageQuotaBytes, + SoraStorageUsedBytes: u.SoraStorageUsedBytes, + TotpSecretEncrypted: u.TotpSecretEncrypted, + TotpEnabled: u.TotpEnabled, + TotpEnabledAt: u.TotpEnabledAt, + CreatedAt: u.CreatedAt, + UpdatedAt: u.UpdatedAt, } } @@ -477,6 +605,11 @@ func groupEntityToService(g *dbent.Group) *service.Group { ImagePrice1K: g.ImagePrice1k, ImagePrice2K: g.ImagePrice2k, ImagePrice4K: g.ImagePrice4k, + SoraImagePrice360: g.SoraImagePrice360, + SoraImagePrice540: g.SoraImagePrice540, + SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd, + SoraStorageQuotaBytes: g.SoraStorageQuotaBytes, DefaultValidityDays: g.DefaultValidityDays, ClaudeCodeOnly: g.ClaudeCodeOnly, FallbackGroupID: g.FallbackGroupID, diff --git a/backend/internal/repository/api_key_repo_integration_test.go b/backend/internal/repository/api_key_repo_integration_test.go index 879a0576..80714614 100644 --- a/backend/internal/repository/api_key_repo_integration_test.go +++ b/backend/internal/repository/api_key_repo_integration_test.go @@ -4,11 +4,14 @@ package repository import ( "context" + "sync" "testing" + "time" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -23,7 +26,7 @@ func (s *APIKeyRepoSuite) SetupTest() 
{ s.ctx = context.Background() tx := testEntTx(s.T()) s.client = tx.Client() - s.repo = NewAPIKeyRepository(s.client).(*apiKeyRepository) + s.repo = newAPIKeyRepositoryWithSQL(s.client, tx) } func TestAPIKeyRepoSuite(t *testing.T) { @@ -155,7 +158,7 @@ func (s *APIKeyRepoSuite) TestListByUserID() { s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil) s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil) - keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) + keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}, service.APIKeyListFilters{}) s.Require().NoError(err, "ListByUserID") s.Require().Len(keys, 2) s.Require().Equal(int64(2), page.Total) @@ -167,7 +170,7 @@ func (s *APIKeyRepoSuite) TestListByUserID_Pagination() { s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil) } - keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2}) + keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2}, service.APIKeyListFilters{}) s.Require().NoError(err) s.Require().Len(keys, 2) s.Require().Equal(int64(5), page.Total) @@ -311,7 +314,7 @@ func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() { s.Require().Equal(service.StatusDisabled, got2.Status) s.Require().Nil(got2.GroupID) - keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}) + keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}, service.APIKeyListFilters{}) s.Require().NoError(err, "ListByUserID") s.Require().Equal(int64(1), page.Total) s.Require().Len(keys, 1) @@ -383,3 +386,87 @@ func (s *APIKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, group s.Require().NoError(s.repo.Create(s.ctx, k), "create api key") return k } + +// --- IncrementQuotaUsed --- + +func (s 
*APIKeyRepoSuite) TestIncrementQuotaUsed_Basic() { + user := s.mustCreateUser("incr-basic@test.com") + key := s.mustCreateApiKey(user.ID, "sk-incr-basic", "Incr", nil) + + newQuota, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.5) + s.Require().NoError(err, "IncrementQuotaUsed") + s.Require().Equal(1.5, newQuota, "第一次递增后应为 1.5") + + newQuota, err = s.repo.IncrementQuotaUsed(s.ctx, key.ID, 2.5) + s.Require().NoError(err, "IncrementQuotaUsed second") + s.Require().Equal(4.0, newQuota, "第二次递增后应为 4.0") +} + +func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_NotFound() { + _, err := s.repo.IncrementQuotaUsed(s.ctx, 999999, 1.0) + s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "不存在的 key 应返回 ErrAPIKeyNotFound") +} + +func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() { + user := s.mustCreateUser("incr-deleted@test.com") + key := s.mustCreateApiKey(user.ID, "sk-incr-del", "Deleted", nil) + + s.Require().NoError(s.repo.Delete(s.ctx, key.ID), "Delete") + + _, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.0) + s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound") +} + +// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。 +// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。 +func TestIncrementQuotaUsed_Concurrent(t *testing.T) { + client := testEntClient(t) + repo := NewAPIKeyRepository(client, integrationDB).(*apiKeyRepository) + ctx := context.Background() + + // 创建测试用户和 API Key + u, err := client.User.Create(). + SetEmail("concurrent-incr-" + time.Now().Format(time.RFC3339Nano) + "@test.com"). + SetPasswordHash("hash"). + SetStatus(service.StatusActive). + SetRole(service.RoleUser). 
+ Save(ctx) + require.NoError(t, err, "create user") + + k := &service.APIKey{ + UserID: u.ID, + Key: "sk-concurrent-" + time.Now().Format(time.RFC3339Nano), + Name: "Concurrent", + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, k), "create api key") + t.Cleanup(func() { + _ = client.APIKey.DeleteOneID(k.ID).Exec(ctx) + _ = client.User.DeleteOneID(u.ID).Exec(ctx) + }) + + // 10 个 goroutine 各递增 1.0,总计应为 10.0 + const goroutines = 10 + const increment = 1.0 + var wg sync.WaitGroup + errs := make([]error, goroutines) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + _, errs[idx] = repo.IncrementQuotaUsed(ctx, k.ID, increment) + }(i) + } + wg.Wait() + + for i, e := range errs { + require.NoError(t, e, "goroutine %d failed", i) + } + + // 验证最终结果 + got, err := repo.GetByID(ctx, k.ID) + require.NoError(t, err, "GetByID") + require.Equal(t, float64(goroutines)*increment, got.QuotaUsed, + "并发递增后总和应为 %v,实际为 %v", float64(goroutines)*increment, got.QuotaUsed) +} diff --git a/backend/internal/repository/api_key_repo_last_used_unit_test.go b/backend/internal/repository/api_key_repo_last_used_unit_test.go new file mode 100644 index 00000000..7c6e2850 --- /dev/null +++ b/backend/internal/repository/api_key_repo_last_used_unit_test.go @@ -0,0 +1,156 @@ +package repository + +import ( + "context" + "database/sql" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func newAPIKeyRepoSQLite(t *testing.T) (*apiKeyRepository, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:api_key_repo_last_used?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + 
drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + return &apiKeyRepository{client: client}, client +} + +func mustCreateAPIKeyRepoUser(t *testing.T, ctx context.Context, client *dbent.Client, email string) *service.User { + t.Helper() + u, err := client.User.Create(). + SetEmail(email). + SetPasswordHash("test-password-hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + return userEntityToService(u) +} + +func TestAPIKeyRepository_CreateWithLastUsedAt(t *testing.T) { + repo, client := newAPIKeyRepoSQLite(t) + ctx := context.Background() + user := mustCreateAPIKeyRepoUser(t, ctx, client, "create-last-used@test.com") + + lastUsed := time.Now().UTC().Add(-time.Hour).Truncate(time.Second) + key := &service.APIKey{ + UserID: user.ID, + Key: "sk-create-last-used", + Name: "CreateWithLastUsed", + Status: service.StatusActive, + LastUsedAt: &lastUsed, + } + + require.NoError(t, repo.Create(ctx, key)) + require.NotNil(t, key.LastUsedAt) + require.WithinDuration(t, lastUsed, *key.LastUsedAt, time.Second) + + got, err := repo.GetByID(ctx, key.ID) + require.NoError(t, err) + require.NotNil(t, got.LastUsedAt) + require.WithinDuration(t, lastUsed, *got.LastUsedAt, time.Second) +} + +func TestAPIKeyRepository_UpdateLastUsed(t *testing.T) { + repo, client := newAPIKeyRepoSQLite(t) + ctx := context.Background() + user := mustCreateAPIKeyRepoUser(t, ctx, client, "update-last-used@test.com") + + key := &service.APIKey{ + UserID: user.ID, + Key: "sk-update-last-used", + Name: "UpdateLastUsed", + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, key)) + + before, err := repo.GetByID(ctx, key.ID) + require.NoError(t, err) + require.Nil(t, before.LastUsedAt) + + target := time.Now().UTC().Add(2 * time.Minute).Truncate(time.Second) + require.NoError(t, repo.UpdateLastUsed(ctx, key.ID, target)) + + 
after, err := repo.GetByID(ctx, key.ID) + require.NoError(t, err) + require.NotNil(t, after.LastUsedAt) + require.WithinDuration(t, target, *after.LastUsedAt, time.Second) + require.WithinDuration(t, target, after.UpdatedAt, time.Second) +} + +func TestAPIKeyRepository_UpdateLastUsedDeletedKey(t *testing.T) { + repo, client := newAPIKeyRepoSQLite(t) + ctx := context.Background() + user := mustCreateAPIKeyRepoUser(t, ctx, client, "deleted-last-used@test.com") + + key := &service.APIKey{ + UserID: user.ID, + Key: "sk-update-last-used-deleted", + Name: "UpdateLastUsedDeleted", + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, key)) + require.NoError(t, repo.Delete(ctx, key.ID)) + + err := repo.UpdateLastUsed(ctx, key.ID, time.Now().UTC()) + require.ErrorIs(t, err, service.ErrAPIKeyNotFound) +} + +func TestAPIKeyRepository_UpdateLastUsedDBError(t *testing.T) { + repo, client := newAPIKeyRepoSQLite(t) + ctx := context.Background() + user := mustCreateAPIKeyRepoUser(t, ctx, client, "db-error-last-used@test.com") + + key := &service.APIKey{ + UserID: user.ID, + Key: "sk-update-last-used-db-error", + Name: "UpdateLastUsedDBError", + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, key)) + + require.NoError(t, client.Close()) + err := repo.UpdateLastUsed(ctx, key.ID, time.Now().UTC()) + require.Error(t, err) +} + +func TestAPIKeyRepository_CreateDuplicateKey(t *testing.T) { + repo, client := newAPIKeyRepoSQLite(t) + ctx := context.Background() + user := mustCreateAPIKeyRepoUser(t, ctx, client, "duplicate-key@test.com") + + first := &service.APIKey{ + UserID: user.ID, + Key: "sk-duplicate", + Name: "first", + Status: service.StatusActive, + } + second := &service.APIKey{ + UserID: user.ID, + Key: "sk-duplicate", + Name: "second", + Status: service.StatusActive, + } + + require.NoError(t, repo.Create(ctx, first)) + err := repo.Create(ctx, second) + require.ErrorIs(t, err, service.ErrAPIKeyExists) +} diff --git 
a/backend/internal/repository/billing_cache.go b/backend/internal/repository/billing_cache.go index ac5803a1..4fbdae14 100644 --- a/backend/internal/repository/billing_cache.go +++ b/backend/internal/repository/billing_cache.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log" + "math/rand/v2" "strconv" "time" @@ -13,11 +14,24 @@ import ( ) const ( - billingBalanceKeyPrefix = "billing:balance:" - billingSubKeyPrefix = "billing:sub:" - billingCacheTTL = 5 * time.Minute + billingBalanceKeyPrefix = "billing:balance:" + billingSubKeyPrefix = "billing:sub:" + billingRateLimitKeyPrefix = "apikey:rate:" + billingCacheTTL = 5 * time.Minute + billingCacheJitter = 30 * time.Second + rateLimitCacheTTL = 7 * 24 * time.Hour // 7 days matches the longest window ) +// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩 +func jitteredTTL() time.Duration { + // 只做“减法抖动”,确保实际 TTL 不会超过 billingCacheTTL(避免上界预期被打破)。 + if billingCacheJitter <= 0 { + return billingCacheTTL + } + jitter := time.Duration(rand.IntN(int(billingCacheJitter))) + return billingCacheTTL - jitter +} + // billingBalanceKey generates the Redis key for user balance cache. func billingBalanceKey(userID int64) string { return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) @@ -37,6 +51,20 @@ const ( subFieldVersion = "version" ) +// billingRateLimitKey generates the Redis key for API key rate limit cache. +func billingRateLimitKey(keyID int64) string { + return fmt.Sprintf("%s%d", billingRateLimitKeyPrefix, keyID) +} + +const ( + rateLimitFieldUsage5h = "usage_5h" + rateLimitFieldUsage1d = "usage_1d" + rateLimitFieldUsage7d = "usage_7d" + rateLimitFieldWindow5h = "window_5h" + rateLimitFieldWindow1d = "window_1d" + rateLimitFieldWindow7d = "window_7d" +) + var ( deductBalanceScript = redis.NewScript(` local current = redis.call('GET', KEYS[1]) @@ -61,6 +89,21 @@ var ( redis.call('EXPIRE', KEYS[1], ARGV[2]) return 1 `) + + // updateRateLimitUsageScript atomically increments all three rate limit usage counters. 
+ // Returns 0 if the key doesn't exist (cache miss), 1 on success. + updateRateLimitUsageScript = redis.NewScript(` + local exists = redis.call('EXISTS', KEYS[1]) + if exists == 0 then + return 0 + end + local cost = tonumber(ARGV[1]) + redis.call('HINCRBYFLOAT', KEYS[1], 'usage_5h', cost) + redis.call('HINCRBYFLOAT', KEYS[1], 'usage_1d', cost) + redis.call('HINCRBYFLOAT', KEYS[1], 'usage_7d', cost) + redis.call('EXPIRE', KEYS[1], ARGV[2]) + return 1 + `) ) type billingCache struct { @@ -82,14 +125,15 @@ func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float6 func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error { key := billingBalanceKey(userID) - return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err() + return c.rdb.Set(ctx, key, balance, jitteredTTL()).Err() } func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error { key := billingBalanceKey(userID) - _, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result() + _, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Result() if err != nil && !errors.Is(err, redis.Nil) { log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err) + return err } return nil } @@ -163,16 +207,17 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID pipe := c.rdb.Pipeline() pipe.HSet(ctx, key, fields) - pipe.Expire(ctx, key, billingCacheTTL) + pipe.Expire(ctx, key, jitteredTTL()) _, err := pipe.Exec(ctx) return err } func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error { key := billingSubKey(userID, groupID) - _, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result() + _, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(jitteredTTL().Seconds())).Result() 
if err != nil && !errors.Is(err, redis.Nil) { log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err) + return err } return nil } @@ -181,3 +226,69 @@ func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, key := billingSubKey(userID, groupID) return c.rdb.Del(ctx, key).Err() } + +func (c *billingCache) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*service.APIKeyRateLimitCacheData, error) { + key := billingRateLimitKey(keyID) + result, err := c.rdb.HGetAll(ctx, key).Result() + if err != nil { + return nil, err + } + if len(result) == 0 { + return nil, redis.Nil + } + data := &service.APIKeyRateLimitCacheData{} + if v, ok := result[rateLimitFieldUsage5h]; ok { + data.Usage5h, _ = strconv.ParseFloat(v, 64) + } + if v, ok := result[rateLimitFieldUsage1d]; ok { + data.Usage1d, _ = strconv.ParseFloat(v, 64) + } + if v, ok := result[rateLimitFieldUsage7d]; ok { + data.Usage7d, _ = strconv.ParseFloat(v, 64) + } + if v, ok := result[rateLimitFieldWindow5h]; ok { + data.Window5h, _ = strconv.ParseInt(v, 10, 64) + } + if v, ok := result[rateLimitFieldWindow1d]; ok { + data.Window1d, _ = strconv.ParseInt(v, 10, 64) + } + if v, ok := result[rateLimitFieldWindow7d]; ok { + data.Window7d, _ = strconv.ParseInt(v, 10, 64) + } + return data, nil +} + +func (c *billingCache) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *service.APIKeyRateLimitCacheData) error { + if data == nil { + return nil + } + key := billingRateLimitKey(keyID) + fields := map[string]any{ + rateLimitFieldUsage5h: data.Usage5h, + rateLimitFieldUsage1d: data.Usage1d, + rateLimitFieldUsage7d: data.Usage7d, + rateLimitFieldWindow5h: data.Window5h, + rateLimitFieldWindow1d: data.Window1d, + rateLimitFieldWindow7d: data.Window7d, + } + pipe := c.rdb.Pipeline() + pipe.HSet(ctx, key, fields) + pipe.Expire(ctx, key, rateLimitCacheTTL) + _, err := pipe.Exec(ctx) + return err +} + +func (c *billingCache) 
UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error { + key := billingRateLimitKey(keyID) + _, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(rateLimitCacheTTL.Seconds())).Result() + if err != nil && !errors.Is(err, redis.Nil) { + log.Printf("Warning: update rate limit usage cache failed for api key %d: %v", keyID, err) + return err + } + return nil +} + +func (c *billingCache) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error { + key := billingRateLimitKey(keyID) + return c.rdb.Del(ctx, key).Err() +} diff --git a/backend/internal/repository/billing_cache_integration_test.go b/backend/internal/repository/billing_cache_integration_test.go index 2f7c69a7..4b7377b1 100644 --- a/backend/internal/repository/billing_cache_integration_test.go +++ b/backend/internal/repository/billing_cache_integration_test.go @@ -278,6 +278,90 @@ func (s *BillingCacheSuite) TestSubscriptionCache() { } } +// TestDeductUserBalance_ErrorPropagation 验证 P2-12 修复: +// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。 +func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() { + tests := []struct { + name string + fn func(ctx context.Context, cache service.BillingCache) + expectErr bool + }{ + { + name: "key_not_exists_returns_nil", + fn: func(ctx context.Context, cache service.BillingCache) { + // key 不存在时,Lua 脚本返回 0(redis.Nil),应返回 nil 而非错误 + err := cache.DeductUserBalance(ctx, 99999, 1.0) + require.NoError(s.T(), err, "DeductUserBalance on non-existent key should return nil") + }, + }, + { + name: "existing_key_deducts_successfully", + fn: func(ctx context.Context, cache service.BillingCache) { + require.NoError(s.T(), cache.SetUserBalance(ctx, 200, 50.0)) + err := cache.DeductUserBalance(ctx, 200, 10.0) + require.NoError(s.T(), err, "DeductUserBalance should succeed") + + bal, err := cache.GetUserBalance(ctx, 200) + require.NoError(s.T(), err) + require.Equal(s.T(), 40.0, bal, "余额应为 40.0") + }, + }, + { + name: 
"cancelled_context_propagates_error", + fn: func(ctx context.Context, cache service.BillingCache) { + require.NoError(s.T(), cache.SetUserBalance(ctx, 201, 50.0)) + + cancelCtx, cancel := context.WithCancel(ctx) + cancel() // 立即取消 + + err := cache.DeductUserBalance(cancelCtx, 201, 10.0) + require.Error(s.T(), err, "cancelled context should propagate error") + }, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + rdb := testRedis(s.T()) + cache := NewBillingCache(rdb) + ctx := context.Background() + tt.fn(ctx, cache) + }) + } +} + +// TestUpdateSubscriptionUsage_ErrorPropagation 验证 P2-12 修复: +// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。 +func (s *BillingCacheSuite) TestUpdateSubscriptionUsage_ErrorPropagation() { + s.Run("key_not_exists_returns_nil", func() { + rdb := testRedis(s.T()) + cache := NewBillingCache(rdb) + ctx := context.Background() + + err := cache.UpdateSubscriptionUsage(ctx, 88888, 77777, 1.0) + require.NoError(s.T(), err, "UpdateSubscriptionUsage on non-existent key should return nil") + }) + + s.Run("cancelled_context_propagates_error", func() { + rdb := testRedis(s.T()) + cache := NewBillingCache(rdb) + ctx := context.Background() + + data := &service.SubscriptionCacheData{ + Status: "active", + ExpiresAt: time.Now().Add(1 * time.Hour), + Version: 1, + } + require.NoError(s.T(), cache.SetSubscriptionCache(ctx, 301, 401, data)) + + cancelCtx, cancel := context.WithCancel(ctx) + cancel() + + err := cache.UpdateSubscriptionUsage(cancelCtx, 301, 401, 1.0) + require.Error(s.T(), err, "cancelled context should propagate error") + }) +} + func TestBillingCacheSuite(t *testing.T) { suite.Run(t, new(BillingCacheSuite)) } diff --git a/backend/internal/repository/billing_cache_jitter_test.go b/backend/internal/repository/billing_cache_jitter_test.go new file mode 100644 index 00000000..ba4f2873 --- /dev/null +++ b/backend/internal/repository/billing_cache_jitter_test.go @@ -0,0 +1,82 @@ +package repository + +import ( + "testing" + "time" + + 
"github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- Task 6.1 验证: math/rand/v2 迁移后 jitteredTTL 行为正确 --- + +func TestJitteredTTL_WithinExpectedRange(t *testing.T) { + // jitteredTTL 使用减法抖动: billingCacheTTL - [0, billingCacheJitter) + // 所以结果应在 [billingCacheTTL - billingCacheJitter, billingCacheTTL] 范围内 + lowerBound := billingCacheTTL - billingCacheJitter // 5min - 30s = 4min30s + upperBound := billingCacheTTL // 5min + + for i := 0; i < 200; i++ { + ttl := jitteredTTL() + assert.GreaterOrEqual(t, int64(ttl), int64(lowerBound), + "TTL 不应低于 %v,实际得到 %v", lowerBound, ttl) + assert.LessOrEqual(t, int64(ttl), int64(upperBound), + "TTL 不应超过 %v(上界不变保证),实际得到 %v", upperBound, ttl) + } +} + +func TestJitteredTTL_NeverExceedsBase(t *testing.T) { + // 关键安全性测试:jitteredTTL 使用减法抖动,确保永远不超过 billingCacheTTL + for i := 0; i < 500; i++ { + ttl := jitteredTTL() + assert.LessOrEqual(t, int64(ttl), int64(billingCacheTTL), + "jitteredTTL 不应超过基础 TTL(上界预期不被打破)") + } +} + +func TestJitteredTTL_HasVariance(t *testing.T) { + // 验证抖动确实产生了不同的值 + results := make(map[time.Duration]bool) + for i := 0; i < 100; i++ { + ttl := jitteredTTL() + results[ttl] = true + } + + require.Greater(t, len(results), 1, + "jitteredTTL 应产生不同的值(抖动生效),但 100 次调用结果全部相同") +} + +func TestJitteredTTL_AverageNearCenter(t *testing.T) { + // 验证平均值大约在抖动范围中间 + var sum time.Duration + runs := 1000 + for i := 0; i < runs; i++ { + sum += jitteredTTL() + } + + avg := sum / time.Duration(runs) + expectedCenter := billingCacheTTL - billingCacheJitter/2 // 4min45s + + // 允许 ±5s 的误差 + tolerance := 5 * time.Second + assert.InDelta(t, float64(expectedCenter), float64(avg), float64(tolerance), + "平均 TTL 应接近抖动范围中心 %v", expectedCenter) +} + +func TestBillingKeyGeneration(t *testing.T) { + t.Run("balance_key", func(t *testing.T) { + key := billingBalanceKey(12345) + assert.Equal(t, "billing:balance:12345", key) + }) + + t.Run("sub_key", func(t *testing.T) { + key := billingSubKey(100, 200) + assert.Equal(t, 
"billing:sub:100:200", key) + }) +} + +func BenchmarkJitteredTTL(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = jitteredTTL() + } +} diff --git a/backend/internal/repository/billing_cache_test.go b/backend/internal/repository/billing_cache_test.go index 7d3fd19d..2de1da87 100644 --- a/backend/internal/repository/billing_cache_test.go +++ b/backend/internal/repository/billing_cache_test.go @@ -5,6 +5,7 @@ package repository import ( "math" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -85,3 +86,26 @@ func TestBillingSubKey(t *testing.T) { }) } } + +func TestJitteredTTL(t *testing.T) { + const ( + minTTL = 4*time.Minute + 30*time.Second // 270s = 5min - 30s + maxTTL = 5*time.Minute + 30*time.Second // 330s = 5min + 30s + ) + + for i := 0; i < 200; i++ { + ttl := jitteredTTL() + require.GreaterOrEqual(t, ttl, minTTL, "jitteredTTL() 返回值低于下限: %v", ttl) + require.LessOrEqual(t, ttl, maxTTL, "jitteredTTL() 返回值超过上限: %v", ttl) + } +} + +func TestJitteredTTL_HasVariation(t *testing.T) { + // 多次调用应该产生不同的值(验证抖动存在) + seen := make(map[time.Duration]struct{}, 50) + for i := 0; i < 50; i++ { + seen[jitteredTTL()] = struct{}{} + } + // 50 次调用中应该至少有 2 个不同的值 + require.Greater(t, len(seen), 1, "jitteredTTL() 应产生不同的 TTL 值") +} diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go index fc0d2918..b754bd55 100644 --- a/backend/internal/repository/claude_oauth_service.go +++ b/backend/internal/repository/claude_oauth_service.go @@ -4,13 +4,14 @@ import ( "context" "encoding/json" "fmt" - "log" "net/http" "net/url" "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/util/logredact" @@ -28,11 +29,14 @@ func NewClaudeOAuthClient() service.ClaudeOAuthClient { type claudeOAuthService struct { baseURL string 
tokenURL string - clientFactory func(proxyURL string) *req.Client + clientFactory func(proxyURL string) (*req.Client, error) } func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) { - client := s.clientFactory(proxyURL) + client, err := s.clientFactory(proxyURL) + if err != nil { + return "", fmt.Errorf("create HTTP client: %w", err) + } var orgs []struct { UUID string `json:"uuid"` @@ -41,7 +45,7 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey } targetURL := s.baseURL + "/api/organizations" - log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1: Getting organization UUID from %s", targetURL) resp, err := client.R(). SetContext(ctx). @@ -53,11 +57,11 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey Get(targetURL) if err != nil { - log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 FAILED - Request error: %v", err) return "", fmt.Errorf("request failed: %w", err) } - log.Printf("[OAuth] Step 1 Response - Status: %d", resp.StatusCode) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 Response - Status: %d", resp.StatusCode) if !resp.IsSuccessState() { return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String()) @@ -69,26 +73,29 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey // 如果只有一个组织,直接使用 if len(orgs) == 1 { - log.Printf("[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name) return orgs[0].UUID, nil } // 如果有多个组织,优先选择 raven_type 为 "team" 的组织 for _, org := range orgs { if org.RavenType != nil && 
*org.RavenType == "team" { - log.Printf("[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s", + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s", org.UUID, org.Name, *org.RavenType) return org.UUID, nil } } // 如果没有 team 类型的组织,使用第一个 - log.Printf("[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name) return orgs[0].UUID, nil } func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) { - client := s.clientFactory(proxyURL) + client, err := s.clientFactory(proxyURL) + if err != nil { + return "", fmt.Errorf("create HTTP client: %w", err) + } authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID) @@ -103,9 +110,9 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe "code_challenge_method": "S256", } - log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2: Getting authorization code from %s", authURL) reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody)) - log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON)) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 Request Body: %s", string(reqBodyJSON)) var result struct { RedirectURI string `json:"redirect_uri"` @@ -128,11 +135,11 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe Post(authURL) if err != nil { - log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 FAILED - Request error: %v", err) return "", fmt.Errorf("request failed: %w", err) } 
- log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes())) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes())) if !resp.IsSuccessState() { return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String()) @@ -160,12 +167,15 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe fullCode = authCode + "#" + responseState } - log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code") + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 SUCCESS - Got authorization code") return fullCode, nil } func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { - client := s.clientFactory(proxyURL) + client, err := s.clientFactory(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } // Parse code which may contain state in format "authCode#state" authCode := code @@ -192,9 +202,9 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds } - log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL) reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody)) - log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON)) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 Request Body: %s", string(reqBodyJSON)) var tokenResp oauth.TokenResponse @@ -208,22 +218,25 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod Post(s.tokenURL) if err != nil { - log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err) + 
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 FAILED - Request error: %v", err) return nil, fmt.Errorf("request failed: %w", err) } - log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes())) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes())) if !resp.IsSuccessState() { return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String()) } - log.Printf("[OAuth] Step 3 SUCCESS - Got access token") + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 SUCCESS - Got access token") return &tokenResp, nil } func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { - client := s.clientFactory(proxyURL) + client, err := s.clientFactory(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } reqBody := map[string]any{ "grant_type": "refresh_token", @@ -253,16 +266,20 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro return &tokenResp, nil } -func createReqClient(proxyURL string) *req.Client { +func createReqClient(proxyURL string) (*req.Client, error) { // 禁用 CookieJar,确保每次授权都是干净的会话 client := req.C(). SetTimeout(60 * time.Second). ImpersonateChrome(). 
SetCookieJar(nil) // 禁用 CookieJar - if strings.TrimSpace(proxyURL) != "" { - client.SetProxyURL(strings.TrimSpace(proxyURL)) + trimmed, _, err := proxyurl.Parse(proxyURL) + if err != nil { + return nil, err + } + if trimmed != "" { + client.SetProxyURL(trimmed) } - return client + return client, nil } diff --git a/backend/internal/repository/claude_oauth_service_test.go b/backend/internal/repository/claude_oauth_service_test.go index 7395c6d8..c6383033 100644 --- a/backend/internal/repository/claude_oauth_service_test.go +++ b/backend/internal/repository/claude_oauth_service_test.go @@ -91,7 +91,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() { require.True(s.T(), ok, "type assertion failed") s.client = client s.client.baseURL = "http://in-process" - s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } + s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil } got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "") @@ -169,7 +169,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() { require.True(s.T(), ok, "type assertion failed") s.client = client s.client.baseURL = "http://in-process" - s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } + s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil } code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeInference, "cc", "st", "") @@ -276,7 +276,7 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() { require.True(s.T(), ok, "type assertion failed") s.client = client s.client.tokenURL = "http://in-process/token" - s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } + s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil } resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", 
tt.isSetupToken) @@ -372,7 +372,7 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() { require.True(s.T(), ok, "type assertion failed") s.client = client s.client.tokenURL = "http://in-process/token" - s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } + s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil } resp, err := s.client.RefreshToken(context.Background(), "rt", "") diff --git a/backend/internal/repository/claude_usage_service.go b/backend/internal/repository/claude_usage_service.go index 1198f472..f6054828 100644 --- a/backend/internal/repository/claude_usage_service.go +++ b/backend/internal/repository/claude_usage_service.go @@ -83,7 +83,7 @@ func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *se AllowPrivateHosts: s.allowPrivateHosts, }) if err != nil { - client = &http.Client{Timeout: 30 * time.Second} + return nil, fmt.Errorf("create http client failed: %w", err) } resp, err = client.Do(req) diff --git a/backend/internal/repository/claude_usage_service_test.go b/backend/internal/repository/claude_usage_service_test.go index 2e10f3e5..cbd0b6d3 100644 --- a/backend/internal/repository/claude_usage_service_test.go +++ b/backend/internal/repository/claude_usage_service_test.go @@ -50,7 +50,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() { allowPrivateHosts: true, } - resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url") + resp, err := s.fetcher.FetchUsage(context.Background(), "at", "") require.NoError(s.T(), err, "FetchUsage") require.Equal(s.T(), 12.5, resp.FiveHour.Utilization, "FiveHour utilization mismatch") require.Equal(s.T(), 34.0, resp.SevenDay.Utilization, "SevenDay utilization mismatch") @@ -112,6 +112,17 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() { require.Error(s.T(), err, "expected error for cancelled context") } +func (s *ClaudeUsageServiceSuite) 
TestFetchUsage_InvalidProxyReturnsError() { + s.fetcher = &claudeUsageService{ + usageURL: "http://example.com", + allowPrivateHosts: true, + } + + _, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url") + require.Error(s.T(), err) + require.ErrorContains(s.T(), err, "create http client failed") +} + func TestClaudeUsageServiceSuite(t *testing.T) { suite.Run(t, new(ClaudeUsageServiceSuite)) } diff --git a/backend/internal/repository/concurrency_cache.go b/backend/internal/repository/concurrency_cache.go index cc0c6db5..a2552715 100644 --- a/backend/internal/repository/concurrency_cache.go +++ b/backend/internal/repository/concurrency_cache.go @@ -147,100 +147,6 @@ var ( return 1 `) - // getAccountsLoadBatchScript - batch load query with expired slot cleanup - // ARGV[1] = slot TTL (seconds) - // ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ... - getAccountsLoadBatchScript = redis.NewScript(` - local result = {} - local slotTTL = tonumber(ARGV[1]) - - -- Get current server time - local timeResult = redis.call('TIME') - local nowSeconds = tonumber(timeResult[1]) - local cutoffTime = nowSeconds - slotTTL - - local i = 2 - while i <= #ARGV do - local accountID = ARGV[i] - local maxConcurrency = tonumber(ARGV[i + 1]) - - local slotKey = 'concurrency:account:' .. accountID - - -- Clean up expired slots before counting - redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime) - local currentConcurrency = redis.call('ZCARD', slotKey) - - local waitKey = 'wait:account:' .. 
accountID - local waitingCount = redis.call('GET', waitKey) - if waitingCount == false then - waitingCount = 0 - else - waitingCount = tonumber(waitingCount) - end - - local loadRate = 0 - if maxConcurrency > 0 then - loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency) - end - - table.insert(result, accountID) - table.insert(result, currentConcurrency) - table.insert(result, waitingCount) - table.insert(result, loadRate) - - i = i + 2 - end - - return result - `) - - // getUsersLoadBatchScript - batch load query for users with expired slot cleanup - // ARGV[1] = slot TTL (seconds) - // ARGV[2..n] = userID1, maxConcurrency1, userID2, maxConcurrency2, ... - getUsersLoadBatchScript = redis.NewScript(` - local result = {} - local slotTTL = tonumber(ARGV[1]) - - -- Get current server time - local timeResult = redis.call('TIME') - local nowSeconds = tonumber(timeResult[1]) - local cutoffTime = nowSeconds - slotTTL - - local i = 2 - while i <= #ARGV do - local userID = ARGV[i] - local maxConcurrency = tonumber(ARGV[i + 1]) - - local slotKey = 'concurrency:user:' .. userID - - -- Clean up expired slots before counting - redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime) - local currentConcurrency = redis.call('ZCARD', slotKey) - - local waitKey = 'concurrency:wait:' .. 
userID - local waitingCount = redis.call('GET', waitKey) - if waitingCount == false then - waitingCount = 0 - else - waitingCount = tonumber(waitingCount) - end - - local loadRate = 0 - if maxConcurrency > 0 then - loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency) - end - - table.insert(result, userID) - table.insert(result, currentConcurrency) - table.insert(result, waitingCount) - table.insert(result, loadRate) - - i = i + 2 - end - - return result - `) - // cleanupExpiredSlotsScript - remove expired slots // KEYS[1] = concurrency:account:{accountID} // ARGV[1] = TTL (seconds) @@ -321,6 +227,43 @@ func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID return result, nil } +func (c *concurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { + if len(accountIDs) == 0 { + return map[int64]int{}, nil + } + + now, err := c.rdb.Time(ctx).Result() + if err != nil { + return nil, fmt.Errorf("redis TIME: %w", err) + } + cutoffTime := now.Unix() - int64(c.slotTTLSeconds) + + pipe := c.rdb.Pipeline() + type accountCmd struct { + accountID int64 + zcardCmd *redis.IntCmd + } + cmds := make([]accountCmd, 0, len(accountIDs)) + for _, accountID := range accountIDs { + slotKey := accountSlotKeyPrefix + strconv.FormatInt(accountID, 10) + pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10)) + cmds = append(cmds, accountCmd{ + accountID: accountID, + zcardCmd: pipe.ZCard(ctx, slotKey), + }) + } + + if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) { + return nil, fmt.Errorf("pipeline exec: %w", err) + } + + result := make(map[int64]int, len(accountIDs)) + for _, cmd := range cmds { + result[cmd.accountID] = int(cmd.zcardCmd.Val()) + } + return result, nil +} + // User slot operations func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { @@ -399,29 
+342,53 @@ func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts [] return map[int64]*service.AccountLoadInfo{}, nil } - args := []any{c.slotTTLSeconds} - for _, acc := range accounts { - args = append(args, acc.ID, acc.MaxConcurrency) - } - - result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice() + // 使用 Pipeline 替代 Lua 脚本,兼容 Redis Cluster(Lua 内动态拼 key 会 CROSSSLOT)。 + // 每个账号执行 3 个命令:ZREMRANGEBYSCORE(清理过期)、ZCARD(并发数)、GET(等待数)。 + now, err := c.rdb.Time(ctx).Result() if err != nil { - return nil, err + return nil, fmt.Errorf("redis TIME: %w", err) + } + cutoffTime := now.Unix() - int64(c.slotTTLSeconds) + + pipe := c.rdb.Pipeline() + + type accountCmds struct { + id int64 + maxConcurrency int + zcardCmd *redis.IntCmd + getCmd *redis.StringCmd + } + cmds := make([]accountCmds, 0, len(accounts)) + for _, acc := range accounts { + slotKey := accountSlotKeyPrefix + strconv.FormatInt(acc.ID, 10) + waitKey := accountWaitKeyPrefix + strconv.FormatInt(acc.ID, 10) + pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10)) + ac := accountCmds{ + id: acc.ID, + maxConcurrency: acc.MaxConcurrency, + zcardCmd: pipe.ZCard(ctx, slotKey), + getCmd: pipe.Get(ctx, waitKey), + } + cmds = append(cmds, ac) } - loadMap := make(map[int64]*service.AccountLoadInfo) - for i := 0; i < len(result); i += 4 { - if i+3 >= len(result) { - break + if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) { + return nil, fmt.Errorf("pipeline exec: %w", err) + } + + loadMap := make(map[int64]*service.AccountLoadInfo, len(accounts)) + for _, ac := range cmds { + currentConcurrency := int(ac.zcardCmd.Val()) + waitingCount := 0 + if v, err := ac.getCmd.Int(); err == nil { + waitingCount = v } - - accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64) - currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1])) - waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2])) - loadRate, _ 
:= strconv.Atoi(fmt.Sprintf("%v", result[i+3])) - - loadMap[accountID] = &service.AccountLoadInfo{ - AccountID: accountID, + loadRate := 0 + if ac.maxConcurrency > 0 { + loadRate = (currentConcurrency + waitingCount) * 100 / ac.maxConcurrency + } + loadMap[ac.id] = &service.AccountLoadInfo{ + AccountID: ac.id, CurrentConcurrency: currentConcurrency, WaitingCount: waitingCount, LoadRate: loadRate, @@ -436,29 +403,52 @@ func (c *concurrencyCache) GetUsersLoadBatch(ctx context.Context, users []servic return map[int64]*service.UserLoadInfo{}, nil } - args := []any{c.slotTTLSeconds} - for _, u := range users { - args = append(args, u.ID, u.MaxConcurrency) - } - - result, err := getUsersLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice() + // 使用 Pipeline 替代 Lua 脚本,兼容 Redis Cluster。 + now, err := c.rdb.Time(ctx).Result() if err != nil { - return nil, err + return nil, fmt.Errorf("redis TIME: %w", err) + } + cutoffTime := now.Unix() - int64(c.slotTTLSeconds) + + pipe := c.rdb.Pipeline() + + type userCmds struct { + id int64 + maxConcurrency int + zcardCmd *redis.IntCmd + getCmd *redis.StringCmd + } + cmds := make([]userCmds, 0, len(users)) + for _, u := range users { + slotKey := userSlotKeyPrefix + strconv.FormatInt(u.ID, 10) + waitKey := waitQueueKeyPrefix + strconv.FormatInt(u.ID, 10) + pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10)) + uc := userCmds{ + id: u.ID, + maxConcurrency: u.MaxConcurrency, + zcardCmd: pipe.ZCard(ctx, slotKey), + getCmd: pipe.Get(ctx, waitKey), + } + cmds = append(cmds, uc) } - loadMap := make(map[int64]*service.UserLoadInfo) - for i := 0; i < len(result); i += 4 { - if i+3 >= len(result) { - break + if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) { + return nil, fmt.Errorf("pipeline exec: %w", err) + } + + loadMap := make(map[int64]*service.UserLoadInfo, len(users)) + for _, uc := range cmds { + currentConcurrency := int(uc.zcardCmd.Val()) + waitingCount := 0 + if v, err := 
uc.getCmd.Int(); err == nil { + waitingCount = v } - - userID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64) - currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1])) - waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2])) - loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3])) - - loadMap[userID] = &service.UserLoadInfo{ - UserID: userID, + loadRate := 0 + if uc.maxConcurrency > 0 { + loadRate = (currentConcurrency + waitingCount) * 100 / uc.maxConcurrency + } + loadMap[uc.id] = &service.UserLoadInfo{ + UserID: uc.id, CurrentConcurrency: currentConcurrency, WaitingCount: waitingCount, LoadRate: loadRate, diff --git a/backend/internal/repository/ent.go b/backend/internal/repository/ent.go index d7d574e8..5f3f5a84 100644 --- a/backend/internal/repository/ent.go +++ b/backend/internal/repository/ent.go @@ -5,6 +5,7 @@ package repository import ( "context" "database/sql" + "fmt" "time" "github.com/Wei-Shaw/sub2api/ent" @@ -66,6 +67,18 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) { // 创建 Ent 客户端,绑定到已配置的数据库驱动。 client := ent.NewClient(ent.Driver(drv)) + // 启动阶段:从配置或数据库中确保系统密钥可用。 + if err := ensureBootstrapSecrets(migrationCtx, client, cfg); err != nil { + _ = client.Close() + return nil, nil, err + } + + // 在密钥补齐后执行完整配置校验,避免空 jwt.secret 导致服务运行时失败。 + if err := cfg.Validate(); err != nil { + _ = client.Close() + return nil, nil, fmt.Errorf("validate config after secret bootstrap: %w", err) + } + // SIMPLE 模式:启动时补齐各平台默认分组。 // - anthropic/openai/gemini: 确保存在 -default // - antigravity: 仅要求存在 >=2 个未软删除分组(用于 claude/gemini 混合调度场景) diff --git a/backend/internal/repository/gateway_cache_integration_test.go b/backend/internal/repository/gateway_cache_integration_test.go index 2fdaa3d1..0eebc33f 100644 --- a/backend/internal/repository/gateway_cache_integration_test.go +++ b/backend/internal/repository/gateway_cache_integration_test.go @@ -104,7 +104,6 @@ func (s *GatewayCacheSuite) 
TestGetSessionAccountID_CorruptedValue() { require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil") } - func TestGatewayCacheSuite(t *testing.T) { suite.Run(t, new(GatewayCacheSuite)) } diff --git a/backend/internal/repository/gemini_drive_client.go b/backend/internal/repository/gemini_drive_client.go new file mode 100644 index 00000000..2e383595 --- /dev/null +++ b/backend/internal/repository/gemini_drive_client.go @@ -0,0 +1,9 @@ +package repository + +import "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" + +// NewGeminiDriveClient creates a concrete DriveClient for Google Drive API operations. +// Returned as geminicli.DriveClient interface for DI (Strategy A). +func NewGeminiDriveClient() geminicli.DriveClient { + return geminicli.NewDriveClient() +} diff --git a/backend/internal/repository/gemini_oauth_client.go b/backend/internal/repository/gemini_oauth_client.go index 8b7fe625..eb14f313 100644 --- a/backend/internal/repository/gemini_oauth_client.go +++ b/backend/internal/repository/gemini_oauth_client.go @@ -26,7 +26,10 @@ func NewGeminiOAuthClient(cfg *config.Config) service.GeminiOAuthClient { } func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) { - client := createGeminiReqClient(proxyURL) + client, err := createGeminiReqClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } // Use different OAuth clients based on oauthType: // - code_assist: always use built-in Gemini CLI OAuth client (public) @@ -72,7 +75,10 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c } func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { - client := createGeminiReqClient(proxyURL) + client, err := createGeminiReqClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create 
HTTP client: %w", err) + } oauthCfgInput := geminicli.OAuthConfig{ ClientID: c.cfg.Gemini.OAuth.ClientID, @@ -111,7 +117,7 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh return &tokenResp, nil } -func createGeminiReqClient(proxyURL string) *req.Client { +func createGeminiReqClient(proxyURL string) (*req.Client, error) { return getSharedReqClient(reqClientOptions{ ProxyURL: proxyURL, Timeout: 60 * time.Second, diff --git a/backend/internal/repository/geminicli_codeassist_client.go b/backend/internal/repository/geminicli_codeassist_client.go index 4f63280d..b5bc6497 100644 --- a/backend/internal/repository/geminicli_codeassist_client.go +++ b/backend/internal/repository/geminicli_codeassist_client.go @@ -26,7 +26,11 @@ func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessTo } var out geminicli.LoadCodeAssistResponse - resp, err := createGeminiCliReqClient(proxyURL).R(). + client, err := createGeminiCliReqClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } + resp, err := client.R(). SetContext(ctx). SetHeader("Authorization", "Bearer "+accessToken). SetHeader("Content-Type", "application/json"). @@ -66,7 +70,11 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken fmt.Printf("[CodeAssist] OnboardUser request body: %+v\n", reqBody) var out geminicli.OnboardUserResponse - resp, err := createGeminiCliReqClient(proxyURL).R(). + client, err := createGeminiCliReqClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create HTTP client: %w", err) + } + resp, err := client.R(). SetContext(ctx). SetHeader("Authorization", "Bearer "+accessToken). SetHeader("Content-Type", "application/json"). 
@@ -98,7 +106,7 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken return &out, nil } -func createGeminiCliReqClient(proxyURL string) *req.Client { +func createGeminiCliReqClient(proxyURL string) (*req.Client, error) { return getSharedReqClient(reqClientOptions{ ProxyURL: proxyURL, Timeout: 30 * time.Second, diff --git a/backend/internal/repository/github_release_service.go b/backend/internal/repository/github_release_service.go index 03f8cc66..ad1f22e3 100644 --- a/backend/internal/repository/github_release_service.go +++ b/backend/internal/repository/github_release_service.go @@ -5,8 +5,10 @@ import ( "encoding/json" "fmt" "io" + "log/slog" "net/http" "os" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" @@ -18,14 +20,27 @@ type githubReleaseClient struct { downloadHTTPClient *http.Client } +type githubReleaseClientError struct { + err error +} + // NewGitHubReleaseClient 创建 GitHub Release 客户端 // proxyURL 为空时直连 GitHub,支持 http/https/socks5/socks5h 协议 -func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient { +// 代理配置失败时行为由 allowDirectOnProxyError 控制: +// - false(默认):返回错误占位客户端,禁止回退到直连 +// - true:回退到直连(仅限管理员显式开启) +func NewGitHubReleaseClient(proxyURL string, allowDirectOnProxyError bool) service.GitHubReleaseClient { + // 安全说明:httpclient.GetClient 的错误链(url.Parse / proxyutil)不含明文代理凭据, + // 但仍通过 slog 仅在服务端日志记录,不会暴露给 HTTP 响应。 sharedClient, err := httpclient.GetClient(httpclient.Options{ Timeout: 30 * time.Second, ProxyURL: proxyURL, }) if err != nil { + if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError { + slog.Warn("proxy client init failed, all requests will fail", "service", "github_release", "error", err) + return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)} + } sharedClient = &http.Client{Timeout: 30 * time.Second} } @@ -35,6 +50,10 @@ 
func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient { ProxyURL: proxyURL, }) if err != nil { + if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError { + slog.Warn("proxy download client init failed, all requests will fail", "service", "github_release", "error", err) + return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)} + } downloadClient = &http.Client{Timeout: 10 * time.Minute} } @@ -44,6 +63,18 @@ func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient { } } +func (c *githubReleaseClientError) FetchLatestRelease(ctx context.Context, repo string) (*service.GitHubRelease, error) { + return nil, c.err +} + +func (c *githubReleaseClientError) DownloadFile(ctx context.Context, url, dest string, maxSize int64) error { + return c.err +} + +func (c *githubReleaseClientError) FetchChecksumFile(ctx context.Context, url string) ([]byte, error) { + return nil, c.err +} + func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo string) (*service.GitHubRelease, error) { url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo) diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 4e7a836f..4edc8534 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -4,11 +4,13 @@ import ( "context" "database/sql" "errors" - "log" + "fmt" + "strings" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/lib/pq" @@ -47,12 +49,17 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er 
SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice4k(groupIn.ImagePrice4K). + SetNillableSoraImagePrice360(groupIn.SoraImagePrice360). + SetNillableSoraImagePrice540(groupIn.SoraImagePrice540). + SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest). + SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD). SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetNillableFallbackGroupID(groupIn.FallbackGroupID). SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). - SetMcpXMLInject(groupIn.MCPXMLInject) + SetMcpXMLInject(groupIn.MCPXMLInject). + SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes) // 设置模型路由配置 if groupIn.ModelRouting != nil { @@ -68,7 +75,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er groupIn.CreatedAt = created.CreatedAt groupIn.UpdatedAt = created.UpdatedAt if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue group create failed: group=%d err=%v", groupIn.ID, err) + logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group create failed: group=%d err=%v", groupIn.ID, err) } } return translatePersistenceError(err, nil, service.ErrGroupExists) @@ -110,10 +117,47 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice4k(groupIn.ImagePrice4K). + SetNillableSoraImagePrice360(groupIn.SoraImagePrice360). + SetNillableSoraImagePrice540(groupIn.SoraImagePrice540). + SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest). + SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD). 
SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). - SetMcpXMLInject(groupIn.MCPXMLInject) + SetMcpXMLInject(groupIn.MCPXMLInject). + SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes) + + // 显式处理可空字段:nil 需要 clear,非 nil 需要 set。 + if groupIn.DailyLimitUSD != nil { + builder = builder.SetDailyLimitUsd(*groupIn.DailyLimitUSD) + } else { + builder = builder.ClearDailyLimitUsd() + } + if groupIn.WeeklyLimitUSD != nil { + builder = builder.SetWeeklyLimitUsd(*groupIn.WeeklyLimitUSD) + } else { + builder = builder.ClearWeeklyLimitUsd() + } + if groupIn.MonthlyLimitUSD != nil { + builder = builder.SetMonthlyLimitUsd(*groupIn.MonthlyLimitUSD) + } else { + builder = builder.ClearMonthlyLimitUsd() + } + if groupIn.ImagePrice1K != nil { + builder = builder.SetImagePrice1k(*groupIn.ImagePrice1K) + } else { + builder = builder.ClearImagePrice1k() + } + if groupIn.ImagePrice2K != nil { + builder = builder.SetImagePrice2k(*groupIn.ImagePrice2K) + } else { + builder = builder.ClearImagePrice2k() + } + if groupIn.ImagePrice4K != nil { + builder = builder.SetImagePrice4k(*groupIn.ImagePrice4K) + } else { + builder = builder.ClearImagePrice4k() + } // 处理 FallbackGroupID:nil 时清除,否则设置 if groupIn.FallbackGroupID != nil { @@ -144,7 +188,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er } groupIn.UpdatedAt = updated.UpdatedAt if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue group update failed: group=%d err=%v", groupIn.ID, err) + logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group update failed: group=%d err=%v", groupIn.ID, err) } return nil } @@ -155,7 +199,7 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error { return translatePersistenceError(err, service.ErrGroupNotFound, nil) } if err := 
enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue group delete failed: group=%d err=%v", id, err) + logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group delete failed: group=%d err=%v", id, err) } return nil } @@ -183,7 +227,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination q = q.Where(group.IsExclusiveEQ(*isExclusive)) } - total, err := q.Count(ctx) + total, err := q.Clone().Count(ctx) if err != nil { return nil, nil, err } @@ -273,6 +317,54 @@ func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, return r.client.Group.Query().Where(group.NameEQ(name)).Exist(ctx) } +// ExistsByIDs 批量检查分组是否存在(仅检查未软删除记录)。 +// 返回结构:map[groupID]exists。 +func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error) { + result := make(map[int64]bool, len(ids)) + if len(ids) == 0 { + return result, nil + } + + uniqueIDs := make([]int64, 0, len(ids)) + seen := make(map[int64]struct{}, len(ids)) + for _, id := range ids { + if id <= 0 { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + uniqueIDs = append(uniqueIDs, id) + result[id] = false + } + if len(uniqueIDs) == 0 { + return result, nil + } + + rows, err := r.sql.QueryContext(ctx, ` + SELECT id + FROM groups + WHERE id = ANY($1) AND deleted_at IS NULL + `, pq.Array(uniqueIDs)) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + return nil, err + } + result[id] = true + } + if err := rows.Err(); err != nil { + return nil, err + } + return result, nil +} + func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { var count int64 if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != 
nil { @@ -288,7 +380,7 @@ func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, grou } affected, _ := res.RowsAffected() if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue group account clear failed: group=%d err=%v", groupID, err) + logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group account clear failed: group=%d err=%v", groupID, err) } return affected, nil } @@ -398,7 +490,7 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, } } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue group cascade delete failed: group=%d err=%v", id, err) + logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group cascade delete failed: group=%d err=%v", id, err) } return affectedUserIDs, nil @@ -492,7 +584,7 @@ func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64 // 发送调度器事件 if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err) + logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err) } return nil @@ -504,22 +596,72 @@ func (r *groupRepository) UpdateSortOrders(ctx context.Context, updates []servic return nil } - // 使用事务批量更新 - tx, err := r.client.Tx(ctx) + // 去重后保留最后一次排序值,避免重复 ID 造成 CASE 分支冲突。 + sortOrderByID := make(map[int64]int, len(updates)) + groupIDs := make([]int64, 0, len(updates)) + for _, u := range updates { + if u.ID <= 0 { + continue + } + if _, exists := sortOrderByID[u.ID]; !exists { + groupIDs = append(groupIDs, u.ID) + } + sortOrderByID[u.ID] = u.SortOrder + } + if len(groupIDs) == 0 { + return nil 
+ } + + // 与旧实现保持一致:任何不存在/已删除的分组都返回 not found,且不执行更新。 + var existingCount int + if err := scanSingleRow( + ctx, + r.sql, + `SELECT COUNT(*) FROM groups WHERE deleted_at IS NULL AND id = ANY($1)`, + []any{pq.Array(groupIDs)}, + &existingCount, + ); err != nil { + return err + } + if existingCount != len(groupIDs) { + return service.ErrGroupNotFound + } + + args := make([]any, 0, len(groupIDs)*2+1) + caseClauses := make([]string, 0, len(groupIDs)) + placeholder := 1 + for _, id := range groupIDs { + caseClauses = append(caseClauses, fmt.Sprintf("WHEN $%d THEN $%d", placeholder, placeholder+1)) + args = append(args, id, sortOrderByID[id]) + placeholder += 2 + } + args = append(args, pq.Array(groupIDs)) + + query := fmt.Sprintf(` + UPDATE groups + SET sort_order = CASE id + %s + ELSE sort_order + END + WHERE deleted_at IS NULL AND id = ANY($%d) + `, strings.Join(caseClauses, "\n\t\t\t"), placeholder) + + result, err := r.sql.ExecContext(ctx, query, args...) if err != nil { return err } - defer func() { _ = tx.Rollback() }() - - for _, u := range updates { - if _, err := tx.Group.UpdateOneID(u.ID).SetSortOrder(u.SortOrder).Save(ctx); err != nil { - return translatePersistenceError(err, service.ErrGroupNotFound, nil) - } - } - - if err := tx.Commit(); err != nil { + affected, err := result.RowsAffected() + if err != nil { return err } + if affected != int64(len(groupIDs)) { + return service.ErrGroupNotFound + } + for _, id := range groupIDs { + if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil { + logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group sort update failed: group=%d err=%v", id, err) + } + } return nil } diff --git a/backend/internal/repository/group_repo_integration_test.go b/backend/internal/repository/group_repo_integration_test.go index c31a9ec4..4a849a46 100644 --- a/backend/internal/repository/group_repo_integration_test.go +++ 
b/backend/internal/repository/group_repo_integration_test.go @@ -352,6 +352,81 @@ func (s *GroupRepoSuite) TestListWithFilters_Search() { }) } +func (s *GroupRepoSuite) TestUpdateSortOrders_BatchCaseWhen() { + g1 := &service.Group{ + Name: "sort-g1", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + g2 := &service.Group{ + Name: "sort-g2", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + g3 := &service.Group{ + Name: "sort-g3", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + s.Require().NoError(s.repo.Create(s.ctx, g1)) + s.Require().NoError(s.repo.Create(s.ctx, g2)) + s.Require().NoError(s.repo.Create(s.ctx, g3)) + + err := s.repo.UpdateSortOrders(s.ctx, []service.GroupSortOrderUpdate{ + {ID: g1.ID, SortOrder: 30}, + {ID: g2.ID, SortOrder: 10}, + {ID: g3.ID, SortOrder: 20}, + {ID: g2.ID, SortOrder: 15}, // 重复 ID 应以最后一次为准 + }) + s.Require().NoError(err) + + got1, err := s.repo.GetByID(s.ctx, g1.ID) + s.Require().NoError(err) + got2, err := s.repo.GetByID(s.ctx, g2.ID) + s.Require().NoError(err) + got3, err := s.repo.GetByID(s.ctx, g3.ID) + s.Require().NoError(err) + s.Require().Equal(30, got1.SortOrder) + s.Require().Equal(15, got2.SortOrder) + s.Require().Equal(20, got3.SortOrder) +} + +func (s *GroupRepoSuite) TestUpdateSortOrders_MissingGroupNoPartialUpdate() { + g1 := &service.Group{ + Name: "sort-no-partial", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + s.Require().NoError(s.repo.Create(s.ctx, g1)) + + before, err := s.repo.GetByID(s.ctx, g1.ID) + 
s.Require().NoError(err) + beforeSort := before.SortOrder + + err = s.repo.UpdateSortOrders(s.ctx, []service.GroupSortOrderUpdate{ + {ID: g1.ID, SortOrder: 99}, + {ID: 99999999, SortOrder: 1}, + }) + s.Require().Error(err) + s.Require().ErrorIs(err, service.ErrGroupNotFound) + + after, err := s.repo.GetByID(s.ctx, g1.ID) + s.Require().NoError(err) + s.Require().Equal(beforeSort, after.SortOrder) +} + func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { g1 := &service.Group{ Name: "g1", diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go index b0f15f19..a4674c1a 100644 --- a/backend/internal/repository/http_upstream.go +++ b/backend/internal/repository/http_upstream.go @@ -14,6 +14,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" "github.com/Wei-Shaw/sub2api/internal/service" @@ -235,7 +236,10 @@ func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID in // TLS 指纹客户端使用独立的缓存键,与普通客户端隔离 func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) { isolation := s.getIsolationMode() - proxyKey, parsedProxy := normalizeProxyURL(proxyURL) + proxyKey, parsedProxy, err := normalizeProxyURL(proxyURL) + if err != nil { + return nil, err + } // TLS 指纹客户端使用独立的缓存键,加 "tls:" 前缀 cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID) poolKey := s.buildPoolKey(isolation, accountConcurrency) + ":tls" @@ -373,9 +377,8 @@ func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, ac // - proxy: 按代理地址隔离,同一代理共享客户端 // - account: 按账户隔离,同一账户共享客户端(代理变更时重建) // - account_proxy: 按账户+代理组合隔离,最细粒度 -func (s *httpUpstreamService) getOrCreateClient(proxyURL 
string, accountID int64, accountConcurrency int) *upstreamClientEntry { - entry, _ := s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false) - return entry +func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) { + return s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false) } // getClientEntry 获取或创建客户端条目 @@ -385,7 +388,10 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a // 获取隔离模式 isolation := s.getIsolationMode() // 标准化代理 URL 并解析 - proxyKey, parsedProxy := normalizeProxyURL(proxyURL) + proxyKey, parsedProxy, err := normalizeProxyURL(proxyURL) + if err != nil { + return nil, err + } // 构建缓存键(根据隔离策略不同) cacheKey := buildCacheKey(isolation, proxyKey, accountID) // 构建连接池配置键(用于检测配置变更) @@ -680,17 +686,18 @@ func buildCacheKey(isolation, proxyKey string, accountID int64) string { // - raw: 原始代理 URL 字符串 // // 返回: -// - string: 标准化的代理键(空或解析失败返回 "direct") -// - *url.URL: 解析后的 URL(空或解析失败返回 nil) -func normalizeProxyURL(raw string) (string, *url.URL) { - proxyURL := strings.TrimSpace(raw) - if proxyURL == "" { - return directProxyKey, nil - } - parsed, err := url.Parse(proxyURL) +// - string: 标准化的代理键(空返回 "direct") +// - *url.URL: 解析后的 URL(空返回 nil) +// - error: 非空代理 URL 解析失败时返回错误(禁止回退到直连) +func normalizeProxyURL(raw string) (string, *url.URL, error) { + _, parsed, err := proxyurl.Parse(raw) if err != nil { - return directProxyKey, nil + return "", nil, err } + if parsed == nil { + return directProxyKey, nil, nil + } + // 规范化:小写 scheme/host,去除路径和查询参数 parsed.Scheme = strings.ToLower(parsed.Scheme) parsed.Host = strings.ToLower(parsed.Host) parsed.Path = "" @@ -710,7 +717,7 @@ func normalizeProxyURL(raw string) (string, *url.URL) { parsed.Host = hostname } } - return parsed.String(), parsed + return parsed.String(), parsed, nil } // defaultPoolSettings 获取默认连接池配置 diff --git 
a/backend/internal/repository/http_upstream_benchmark_test.go b/backend/internal/repository/http_upstream_benchmark_test.go index 1e7430a3..89892b3b 100644 --- a/backend/internal/repository/http_upstream_benchmark_test.go +++ b/backend/internal/repository/http_upstream_benchmark_test.go @@ -59,7 +59,10 @@ func BenchmarkHTTPUpstreamProxyClient(b *testing.B) { // 模拟优化后的行为,从缓存获取客户端 b.Run("复用", func(b *testing.B) { // 预热:确保客户端已缓存 - entry := svc.getOrCreateClient(proxyURL, 1, 1) + entry, err := svc.getOrCreateClient(proxyURL, 1, 1) + if err != nil { + b.Fatalf("getOrCreateClient: %v", err) + } client := entry.client b.ResetTimer() // 重置计时器,排除预热时间 for i := 0; i < b.N; i++ { diff --git a/backend/internal/repository/http_upstream_test.go b/backend/internal/repository/http_upstream_test.go index fbe44c5e..b3268463 100644 --- a/backend/internal/repository/http_upstream_test.go +++ b/backend/internal/repository/http_upstream_test.go @@ -44,7 +44,7 @@ func (s *HTTPUpstreamSuite) newService() *httpUpstreamService { // 验证未配置时使用 300 秒默认值 func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() { svc := s.newService() - entry := svc.getOrCreateClient("", 0, 0) + entry := mustGetOrCreateClient(s.T(), svc, "", 0, 0) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") require.Equal(s.T(), 300*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch") @@ -55,25 +55,27 @@ func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() { func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() { s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 7} svc := s.newService() - entry := svc.getOrCreateClient("", 0, 0) + entry := mustGetOrCreateClient(s.T(), svc, "", 0, 0) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch") } -// 
TestGetOrCreateClient_InvalidURLFallsBackToDirect 测试无效代理 URL 回退 -// 验证解析失败时回退到直连模式 -func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLFallsBackToDirect() { +// TestGetOrCreateClient_InvalidURLReturnsError 测试无效代理 URL 返回错误 +// 验证解析失败时拒绝回退到直连模式 +func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLReturnsError() { svc := s.newService() - entry := svc.getOrCreateClient("://bad-proxy-url", 1, 1) - require.Equal(s.T(), directProxyKey, entry.proxyKey, "expected direct proxy fallback") + _, err := svc.getClientEntry("://bad-proxy-url", 1, 1, false, false) + require.Error(s.T(), err, "expected error for invalid proxy URL") } // TestNormalizeProxyURL_Canonicalizes 测试代理 URL 规范化 // 验证等价地址能够映射到同一缓存键 func (s *HTTPUpstreamSuite) TestNormalizeProxyURL_Canonicalizes() { - key1, _ := normalizeProxyURL("http://proxy.local:8080") - key2, _ := normalizeProxyURL("http://proxy.local:8080/") + key1, _, err1 := normalizeProxyURL("http://proxy.local:8080") + require.NoError(s.T(), err1) + key2, _, err2 := normalizeProxyURL("http://proxy.local:8080/") + require.NoError(s.T(), err2) require.Equal(s.T(), key1, key2, "expected normalized proxy keys to match") } @@ -171,8 +173,8 @@ func (s *HTTPUpstreamSuite) TestAccountIsolation_DifferentAccounts() { s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount} svc := s.newService() // 同一代理,不同账户 - entry1 := svc.getOrCreateClient("http://proxy.local:8080", 1, 3) - entry2 := svc.getOrCreateClient("http://proxy.local:8080", 2, 3) + entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy.local:8080", 1, 3) + entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy.local:8080", 2, 3) require.NotSame(s.T(), entry1, entry2, "不同账号不应共享连接池") require.Equal(s.T(), 2, len(svc.clients), "账号隔离应缓存两个客户端") } @@ -183,8 +185,8 @@ func (s *HTTPUpstreamSuite) TestAccountProxyIsolation_DifferentProxy() { s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy} 
svc := s.newService() // 同一账户,不同代理 - entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3) - entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3) + entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 3) + entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 1, 3) require.NotSame(s.T(), entry1, entry2, "账号+代理隔离应区分不同代理") require.Equal(s.T(), 2, len(svc.clients), "账号+代理隔离应缓存两个客户端") } @@ -195,8 +197,8 @@ func (s *HTTPUpstreamSuite) TestAccountModeProxyChangeClearsPool() { s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount} svc := s.newService() // 同一账户,先后使用不同代理 - entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3) - entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3) + entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 3) + entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 1, 3) require.NotSame(s.T(), entry1, entry2, "账号切换代理应创建新连接池") require.Equal(s.T(), 1, len(svc.clients), "账号模式下应仅保留一个连接池") require.False(s.T(), hasEntry(svc, entry1), "旧连接池应被清理") @@ -208,7 +210,7 @@ func (s *HTTPUpstreamSuite) TestAccountConcurrencyOverridesPoolSettings() { s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount} svc := s.newService() // 账户并发数为 12 - entry := svc.getOrCreateClient("", 1, 12) + entry := mustGetOrCreateClient(s.T(), svc, "", 1, 12) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") // 连接池参数应与并发数一致 @@ -228,7 +230,7 @@ func (s *HTTPUpstreamSuite) TestAccountConcurrencyFallbackToDefault() { } svc := s.newService() // 账户并发数为 0,应使用全局配置 - entry := svc.getOrCreateClient("", 1, 0) + entry := mustGetOrCreateClient(s.T(), svc, "", 1, 0) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") require.Equal(s.T(), 66, transport.MaxConnsPerHost, "MaxConnsPerHost fallback mismatch") 
@@ -245,12 +247,12 @@ func (s *HTTPUpstreamSuite) TestEvictOverLimitRemovesOldestIdle() { } svc := s.newService() // 创建两个客户端,设置不同的最后使用时间 - entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 1) - entry2 := svc.getOrCreateClient("http://proxy-b:8080", 2, 1) + entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 1) + entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 2, 1) atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Hour).UnixNano()) // 最久 atomic.StoreInt64(&entry2.lastUsed, time.Now().Add(-time.Hour).UnixNano()) // 创建第三个客户端,触发淘汰 - _ = svc.getOrCreateClient("http://proxy-c:8080", 3, 1) + _ = mustGetOrCreateClient(s.T(), svc, "http://proxy-c:8080", 3, 1) require.LessOrEqual(s.T(), len(svc.clients), 2, "应保持在缓存上限内") require.False(s.T(), hasEntry(svc, entry1), "最久未使用的连接池应被清理") @@ -264,12 +266,12 @@ func (s *HTTPUpstreamSuite) TestIdleTTLDoesNotEvictActive() { ClientIdleTTLSeconds: 1, // 1 秒空闲超时 } svc := s.newService() - entry1 := svc.getOrCreateClient("", 1, 1) + entry1 := mustGetOrCreateClient(s.T(), svc, "", 1, 1) // 设置为很久之前使用,但有活跃请求 atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Minute).UnixNano()) atomic.StoreInt64(&entry1.inFlight, 1) // 模拟有活跃请求 // 创建新客户端,触发淘汰检查 - _ = svc.getOrCreateClient("", 2, 1) + _, _ = svc.getOrCreateClient("", 2, 1) require.True(s.T(), hasEntry(svc, entry1), "有活跃请求时不应回收") } @@ -279,6 +281,14 @@ func TestHTTPUpstreamSuite(t *testing.T) { suite.Run(t, new(HTTPUpstreamSuite)) } +// mustGetOrCreateClient 测试辅助函数,调用 getOrCreateClient 并断言无错误 +func mustGetOrCreateClient(t *testing.T, svc *httpUpstreamService, proxyURL string, accountID int64, concurrency int) *upstreamClientEntry { + t.Helper() + entry, err := svc.getOrCreateClient(proxyURL, accountID, concurrency) + require.NoError(t, err, "getOrCreateClient(%q, %d, %d)", proxyURL, accountID, concurrency) + return entry +} + // hasEntry 检查客户端是否存在于缓存中 // 辅助函数,用于验证淘汰逻辑 func hasEntry(svc *httpUpstreamService, target *upstreamClientEntry) 
bool { diff --git a/backend/internal/repository/idempotency_repo.go b/backend/internal/repository/idempotency_repo.go new file mode 100644 index 00000000..32f2faae --- /dev/null +++ b/backend/internal/repository/idempotency_repo.go @@ -0,0 +1,237 @@ +package repository + +import ( + "context" + "database/sql" + "errors" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +type idempotencyRepository struct { + sql sqlExecutor +} + +func NewIdempotencyRepository(_ *dbent.Client, sqlDB *sql.DB) service.IdempotencyRepository { + return &idempotencyRepository{sql: sqlDB} +} + +func (r *idempotencyRepository) CreateProcessing(ctx context.Context, record *service.IdempotencyRecord) (bool, error) { + if record == nil { + return false, nil + } + query := ` + INSERT INTO idempotency_records ( + scope, idempotency_key_hash, request_fingerprint, status, locked_until, expires_at + ) VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (scope, idempotency_key_hash) DO NOTHING + RETURNING id, created_at, updated_at + ` + var createdAt time.Time + var updatedAt time.Time + err := scanSingleRow(ctx, r.sql, query, []any{ + record.Scope, + record.IdempotencyKeyHash, + record.RequestFingerprint, + record.Status, + record.LockedUntil, + record.ExpiresAt, + }, &record.ID, &createdAt, &updatedAt) + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } + if err != nil { + return false, err + } + record.CreatedAt = createdAt + record.UpdatedAt = updatedAt + return true, nil +} + +func (r *idempotencyRepository) GetByScopeAndKeyHash(ctx context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) { + query := ` + SELECT + id, scope, idempotency_key_hash, request_fingerprint, status, response_status, + response_body, error_reason, locked_until, expires_at, created_at, updated_at + FROM idempotency_records + WHERE scope = $1 AND idempotency_key_hash = $2 + ` + record := &service.IdempotencyRecord{} + var responseStatus 
sql.NullInt64 + var responseBody sql.NullString + var errorReason sql.NullString + var lockedUntil sql.NullTime + err := scanSingleRow(ctx, r.sql, query, []any{scope, keyHash}, + &record.ID, + &record.Scope, + &record.IdempotencyKeyHash, + &record.RequestFingerprint, + &record.Status, + &responseStatus, + &responseBody, + &errorReason, + &lockedUntil, + &record.ExpiresAt, + &record.CreatedAt, + &record.UpdatedAt, + ) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + if err != nil { + return nil, err + } + if responseStatus.Valid { + v := int(responseStatus.Int64) + record.ResponseStatus = &v + } + if responseBody.Valid { + v := responseBody.String + record.ResponseBody = &v + } + if errorReason.Valid { + v := errorReason.String + record.ErrorReason = &v + } + if lockedUntil.Valid { + v := lockedUntil.Time + record.LockedUntil = &v + } + return record, nil +} + +func (r *idempotencyRepository) TryReclaim( + ctx context.Context, + id int64, + fromStatus string, + now, newLockedUntil, newExpiresAt time.Time, +) (bool, error) { + query := ` + UPDATE idempotency_records + SET status = $2, + locked_until = $3, + error_reason = NULL, + updated_at = NOW(), + expires_at = $4 + WHERE id = $1 + AND status = $5 + AND (locked_until IS NULL OR locked_until <= $6) + ` + res, err := r.sql.ExecContext(ctx, query, + id, + service.IdempotencyStatusProcessing, + newLockedUntil, + newExpiresAt, + fromStatus, + now, + ) + if err != nil { + return false, err + } + affected, err := res.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +func (r *idempotencyRepository) ExtendProcessingLock( + ctx context.Context, + id int64, + requestFingerprint string, + newLockedUntil, + newExpiresAt time.Time, +) (bool, error) { + query := ` + UPDATE idempotency_records + SET locked_until = $2, + expires_at = $3, + updated_at = NOW() + WHERE id = $1 + AND status = $4 + AND request_fingerprint = $5 + ` + res, err := r.sql.ExecContext( + ctx, + query, + id, 
+ newLockedUntil, + newExpiresAt, + service.IdempotencyStatusProcessing, + requestFingerprint, + ) + if err != nil { + return false, err + } + affected, err := res.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +func (r *idempotencyRepository) MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error { + query := ` + UPDATE idempotency_records + SET status = $2, + response_status = $3, + response_body = $4, + error_reason = NULL, + locked_until = NULL, + expires_at = $5, + updated_at = NOW() + WHERE id = $1 + ` + _, err := r.sql.ExecContext(ctx, query, + id, + service.IdempotencyStatusSucceeded, + responseStatus, + responseBody, + expiresAt, + ) + return err +} + +func (r *idempotencyRepository) MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error { + query := ` + UPDATE idempotency_records + SET status = $2, + error_reason = $3, + locked_until = $4, + expires_at = $5, + updated_at = NOW() + WHERE id = $1 + ` + _, err := r.sql.ExecContext(ctx, query, + id, + service.IdempotencyStatusFailedRetryable, + errorReason, + lockedUntil, + expiresAt, + ) + return err +} + +func (r *idempotencyRepository) DeleteExpired(ctx context.Context, now time.Time, limit int) (int64, error) { + if limit <= 0 { + limit = 500 + } + query := ` + WITH victims AS ( + SELECT id + FROM idempotency_records + WHERE expires_at <= $1 + ORDER BY expires_at ASC + LIMIT $2 + ) + DELETE FROM idempotency_records + WHERE id IN (SELECT id FROM victims) + ` + res, err := r.sql.ExecContext(ctx, query, now, limit) + if err != nil { + return 0, err + } + return res.RowsAffected() +} diff --git a/backend/internal/repository/idempotency_repo_integration_test.go b/backend/internal/repository/idempotency_repo_integration_test.go new file mode 100644 index 00000000..f163c2f0 --- /dev/null +++ b/backend/internal/repository/idempotency_repo_integration_test.go 
@@ -0,0 +1,149 @@ +//go:build integration + +package repository + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +// hashedTestValue returns a unique SHA-256 hex string (64 chars) that fits VARCHAR(64) columns. +func hashedTestValue(t *testing.T, prefix string) string { + t.Helper() + sum := sha256.Sum256([]byte(uniqueTestValue(t, prefix))) + return hex.EncodeToString(sum[:]) +} + +func TestIdempotencyRepo_CreateProcessing_CompeteSameKey(t *testing.T) { + tx := testTx(t) + repo := &idempotencyRepository{sql: tx} + ctx := context.Background() + + now := time.Now().UTC() + record := &service.IdempotencyRecord{ + Scope: uniqueTestValue(t, "idem-scope-create"), + IdempotencyKeyHash: hashedTestValue(t, "idem-hash"), + RequestFingerprint: hashedTestValue(t, "idem-fp"), + Status: service.IdempotencyStatusProcessing, + LockedUntil: ptrTime(now.Add(30 * time.Second)), + ExpiresAt: now.Add(24 * time.Hour), + } + owner, err := repo.CreateProcessing(ctx, record) + require.NoError(t, err) + require.True(t, owner) + require.NotZero(t, record.ID) + + duplicate := &service.IdempotencyRecord{ + Scope: record.Scope, + IdempotencyKeyHash: record.IdempotencyKeyHash, + RequestFingerprint: hashedTestValue(t, "idem-fp-other"), + Status: service.IdempotencyStatusProcessing, + LockedUntil: ptrTime(now.Add(30 * time.Second)), + ExpiresAt: now.Add(24 * time.Hour), + } + owner, err = repo.CreateProcessing(ctx, duplicate) + require.NoError(t, err) + require.False(t, owner, "same scope+key hash should be de-duplicated") +} + +func TestIdempotencyRepo_TryReclaim_StatusAndLockWindow(t *testing.T) { + tx := testTx(t) + repo := &idempotencyRepository{sql: tx} + ctx := context.Background() + + now := time.Now().UTC() + record := &service.IdempotencyRecord{ + Scope: uniqueTestValue(t, "idem-scope-reclaim"), + IdempotencyKeyHash: hashedTestValue(t, "idem-hash-reclaim"), + 
RequestFingerprint: hashedTestValue(t, "idem-fp-reclaim"), + Status: service.IdempotencyStatusProcessing, + LockedUntil: ptrTime(now.Add(10 * time.Second)), + ExpiresAt: now.Add(24 * time.Hour), + } + owner, err := repo.CreateProcessing(ctx, record) + require.NoError(t, err) + require.True(t, owner) + + require.NoError(t, repo.MarkFailedRetryable( + ctx, + record.ID, + "RETRYABLE_FAILURE", + now.Add(-2*time.Second), + now.Add(24*time.Hour), + )) + + newLockedUntil := now.Add(20 * time.Second) + reclaimed, err := repo.TryReclaim( + ctx, + record.ID, + service.IdempotencyStatusFailedRetryable, + now, + newLockedUntil, + now.Add(24*time.Hour), + ) + require.NoError(t, err) + require.True(t, reclaimed, "failed_retryable + expired lock should allow reclaim") + + got, err := repo.GetByScopeAndKeyHash(ctx, record.Scope, record.IdempotencyKeyHash) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, service.IdempotencyStatusProcessing, got.Status) + require.NotNil(t, got.LockedUntil) + require.True(t, got.LockedUntil.After(now)) + + require.NoError(t, repo.MarkFailedRetryable( + ctx, + record.ID, + "RETRYABLE_FAILURE", + now.Add(20*time.Second), + now.Add(24*time.Hour), + )) + + reclaimed, err = repo.TryReclaim( + ctx, + record.ID, + service.IdempotencyStatusFailedRetryable, + now, + now.Add(40*time.Second), + now.Add(24*time.Hour), + ) + require.NoError(t, err) + require.False(t, reclaimed, "within lock window should not reclaim") +} + +func TestIdempotencyRepo_StatusTransition_ToSucceeded(t *testing.T) { + tx := testTx(t) + repo := &idempotencyRepository{sql: tx} + ctx := context.Background() + + now := time.Now().UTC() + record := &service.IdempotencyRecord{ + Scope: uniqueTestValue(t, "idem-scope-success"), + IdempotencyKeyHash: hashedTestValue(t, "idem-hash-success"), + RequestFingerprint: hashedTestValue(t, "idem-fp-success"), + Status: service.IdempotencyStatusProcessing, + LockedUntil: ptrTime(now.Add(10 * time.Second)), + ExpiresAt: now.Add(24 * 
time.Hour), + } + owner, err := repo.CreateProcessing(ctx, record) + require.NoError(t, err) + require.True(t, owner) + + require.NoError(t, repo.MarkSucceeded(ctx, record.ID, 200, `{"ok":true}`, now.Add(24*time.Hour))) + + got, err := repo.GetByScopeAndKeyHash(ctx, record.Scope, record.IdempotencyKeyHash) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, service.IdempotencyStatusSucceeded, got.Status) + require.NotNil(t, got.ResponseStatus) + require.Equal(t, 200, *got.ResponseStatus) + require.NotNil(t, got.ResponseBody) + require.Equal(t, `{"ok":true}`, *got.ResponseBody) + require.Nil(t, got.LockedUntil) +} diff --git a/backend/internal/repository/identity_cache.go b/backend/internal/repository/identity_cache.go index c4986547..6152dd7a 100644 --- a/backend/internal/repository/identity_cache.go +++ b/backend/internal/repository/identity_cache.go @@ -12,7 +12,7 @@ import ( const ( fingerprintKeyPrefix = "fingerprint:" - fingerprintTTL = 24 * time.Hour + fingerprintTTL = 7 * 24 * time.Hour // 7天,配合每24小时懒续期可保持活跃账号永不过期 maskedSessionKeyPrefix = "masked_session:" maskedSessionTTL = 15 * time.Minute ) diff --git a/backend/internal/repository/migrations_runner.go b/backend/internal/repository/migrations_runner.go index 5912e50f..9cf3b392 100644 --- a/backend/internal/repository/migrations_runner.go +++ b/backend/internal/repository/migrations_runner.go @@ -50,6 +50,30 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions ( // 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。 const migrationsAdvisoryLockID int64 = 694208311321144027 const migrationsLockRetryInterval = 500 * time.Millisecond +const nonTransactionalMigrationSuffix = "_notx.sql" + +type migrationChecksumCompatibilityRule struct { + fileChecksum string + acceptedDBChecksum map[string]struct{} +} + +// migrationChecksumCompatibilityRules 仅用于兼容历史上误修改过的迁移文件 checksum。 +// 规则必须同时匹配「迁移名 + 当前文件 checksum + 历史库 checksum」才会放行,避免放宽全局校验。 +var migrationChecksumCompatibilityRules = 
map[string]migrationChecksumCompatibilityRule{ + "054_drop_legacy_cache_columns.sql": { + fileChecksum: "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", + acceptedDBChecksum: map[string]struct{}{ + "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4": {}, + }, + }, + "061_add_usage_log_request_type.sql": { + fileChecksum: "66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", + acceptedDBChecksum: map[string]struct{}{ + "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0": {}, + "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3": {}, + }, + }, +} // ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。 // @@ -147,6 +171,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error { if rowErr == nil { // 迁移已应用,验证校验和是否匹配 if existing != checksum { + // 兼容特定历史误改场景(仅白名单规则),其余仍保持严格不可变约束。 + if isMigrationChecksumCompatible(name, existing, checksum) { + continue + } // 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。 // 正确的做法是创建新的迁移文件来进行变更。 return fmt.Errorf( @@ -165,8 +193,34 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error { return fmt.Errorf("check migration %s: %w", name, rowErr) } - // 迁移未应用,在事务中执行。 - // 使用事务确保迁移的原子性:要么完全成功,要么完全回滚。 + nonTx, err := validateMigrationExecutionMode(name, content) + if err != nil { + return fmt.Errorf("validate migration %s: %w", name, err) + } + + if nonTx { + // *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。 + // 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。 + statements := splitSQLStatements(content) + for i, stmt := range statements { + trimmed := strings.TrimSpace(stmt) + if trimmed == "" { + continue + } + if stripSQLLineComment(trimmed) == "" { + continue + } + if _, err := db.ExecContext(ctx, trimmed); err != nil { + return fmt.Errorf("apply migration %s (non-tx statement %d): %w", name, i+1, err) + } + } + if _, err := db.ExecContext(ctx, "INSERT INTO schema_migrations (filename, checksum) VALUES ($1, $2)", name, checksum); err != nil { + 
return fmt.Errorf("record migration %s (non-tx): %w", name, err) + } + continue + } + + // 默认迁移在事务中执行,确保原子性:要么完全成功,要么完全回滚。 tx, err := db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("begin migration %s: %w", name, err) @@ -268,6 +322,84 @@ func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) { return version, version, hash, nil } +func isMigrationChecksumCompatible(name, dbChecksum, fileChecksum string) bool { + rule, ok := migrationChecksumCompatibilityRules[name] + if !ok { + return false + } + if rule.fileChecksum != fileChecksum { + return false + } + _, ok = rule.acceptedDBChecksum[dbChecksum] + return ok +} + +func validateMigrationExecutionMode(name, content string) (bool, error) { + normalizedName := strings.ToLower(strings.TrimSpace(name)) + upperContent := strings.ToUpper(content) + nonTx := strings.HasSuffix(normalizedName, nonTransactionalMigrationSuffix) + + if !nonTx { + if strings.Contains(upperContent, "CONCURRENTLY") { + return false, errors.New("CONCURRENTLY statements must be placed in *_notx.sql migrations") + } + return false, nil + } + + if strings.Contains(upperContent, "BEGIN") || strings.Contains(upperContent, "COMMIT") || strings.Contains(upperContent, "ROLLBACK") { + return false, errors.New("*_notx.sql must not contain transaction control statements (BEGIN/COMMIT/ROLLBACK)") + } + + statements := splitSQLStatements(content) + for _, stmt := range statements { + normalizedStmt := strings.ToUpper(stripSQLLineComment(strings.TrimSpace(stmt))) + if normalizedStmt == "" { + continue + } + + if strings.Contains(normalizedStmt, "CONCURRENTLY") { + isCreateIndex := strings.Contains(normalizedStmt, "CREATE") && strings.Contains(normalizedStmt, "INDEX") + isDropIndex := strings.Contains(normalizedStmt, "DROP") && strings.Contains(normalizedStmt, "INDEX") + if !isCreateIndex && !isDropIndex { + return false, errors.New("*_notx.sql currently only supports CREATE/DROP INDEX CONCURRENTLY statements") + } + if isCreateIndex 
&& !strings.Contains(normalizedStmt, "IF NOT EXISTS") { + return false, errors.New("CREATE INDEX CONCURRENTLY in *_notx.sql must include IF NOT EXISTS for idempotency") + } + if isDropIndex && !strings.Contains(normalizedStmt, "IF EXISTS") { + return false, errors.New("DROP INDEX CONCURRENTLY in *_notx.sql must include IF EXISTS for idempotency") + } + continue + } + + return false, errors.New("*_notx.sql must not mix non-CONCURRENTLY SQL statements") + } + + return true, nil +} + +func splitSQLStatements(content string) []string { + parts := strings.Split(content, ";") + out := make([]string, 0, len(parts)) + for _, part := range parts { + if strings.TrimSpace(part) == "" { + continue + } + out = append(out, part) + } + return out +} + +func stripSQLLineComment(s string) string { + lines := strings.Split(s, "\n") + for i, line := range lines { + if idx := strings.Index(line, "--"); idx >= 0 { + lines[i] = line[:idx] + } + } + return strings.TrimSpace(strings.Join(lines, "\n")) +} + // pgAdvisoryLock 获取 PostgreSQL Advisory Lock。 // Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。 // 它非常适合用于应用层面的分布式锁场景,如迁移序列化。 diff --git a/backend/internal/repository/migrations_runner_checksum_test.go b/backend/internal/repository/migrations_runner_checksum_test.go new file mode 100644 index 00000000..6c3ad725 --- /dev/null +++ b/backend/internal/repository/migrations_runner_checksum_test.go @@ -0,0 +1,54 @@ +package repository + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsMigrationChecksumCompatible(t *testing.T) { + t.Run("054历史checksum可兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "054_drop_legacy_cache_columns.sql", + "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4", + "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", + ) + require.True(t, ok) + }) + + t.Run("054在未知文件checksum下不兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "054_drop_legacy_cache_columns.sql", + 
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4", + "0000000000000000000000000000000000000000000000000000000000000000", + ) + require.False(t, ok) + }) + + t.Run("061历史checksum可兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "061_add_usage_log_request_type.sql", + "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0", + "66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", + ) + require.True(t, ok) + }) + + t.Run("061第二个历史checksum可兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "061_add_usage_log_request_type.sql", + "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3", + "66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", + ) + require.True(t, ok) + }) + + t.Run("非白名单迁移不兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "001_init.sql", + "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4", + "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", + ) + require.False(t, ok) + }) +} diff --git a/backend/internal/repository/migrations_runner_extra_test.go b/backend/internal/repository/migrations_runner_extra_test.go new file mode 100644 index 00000000..9f8a94c6 --- /dev/null +++ b/backend/internal/repository/migrations_runner_extra_test.go @@ -0,0 +1,368 @@ +package repository + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "io/fs" + "strings" + "testing" + "testing/fstest" + "time" + + sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/require" +) + +func TestApplyMigrations_NilDB(t *testing.T) { + err := ApplyMigrations(context.Background(), nil) + require.Error(t, err) + require.Contains(t, err.Error(), "nil sql db") +} + +func TestApplyMigrations_DelegatesToApplyMigrationsFS(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)"). 
+ WithArgs(migrationsAdvisoryLockID). + WillReturnError(errors.New("lock failed")) + + err = ApplyMigrations(context.Background(), db) + require.Error(t, err) + require.Contains(t, err.Error(), "acquire migrations lock") + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestLatestMigrationBaseline(t *testing.T) { + t.Run("empty_fs_returns_baseline", func(t *testing.T) { + version, description, hash, err := latestMigrationBaseline(fstest.MapFS{}) + require.NoError(t, err) + require.Equal(t, "baseline", version) + require.Equal(t, "baseline", description) + require.Equal(t, "", hash) + }) + + t.Run("uses_latest_sorted_sql_file", func(t *testing.T) { + fsys := fstest.MapFS{ + "001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")}, + "010_final.sql": &fstest.MapFile{ + Data: []byte("CREATE TABLE t2(id int);"), + }, + } + version, description, hash, err := latestMigrationBaseline(fsys) + require.NoError(t, err) + require.Equal(t, "010_final", version) + require.Equal(t, "010_final", description) + require.Len(t, hash, 64) + }) + + t.Run("read_file_error", func(t *testing.T) { + fsys := fstest.MapFS{ + "010_bad.sql": &fstest.MapFile{Mode: fs.ModeDir}, + } + _, _, _, err := latestMigrationBaseline(fsys) + require.Error(t, err) + }) +} + +func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) { + require.False(t, isMigrationChecksumCompatible("unknown.sql", "db", "file")) + + var ( + name string + rule migrationChecksumCompatibilityRule + ) + for n, r := range migrationChecksumCompatibilityRules { + name = n + rule = r + break + } + require.NotEmpty(t, name) + + require.False(t, isMigrationChecksumCompatible(name, "db-not-accepted", "file-not-match")) + require.False(t, isMigrationChecksumCompatible(name, "db-not-accepted", rule.fileChecksum)) + + var accepted string + for checksum := range rule.acceptedDBChecksum { + accepted = checksum + break + } + require.NotEmpty(t, accepted) + require.True(t, 
isMigrationChecksumCompatible(name, accepted, rule.fileChecksum)) +} + +func TestEnsureAtlasBaselineAligned(t *testing.T) { + t.Run("skip_when_no_legacy_table", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + + err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{}) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("create_atlas_and_insert_baseline_when_empty", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) + mock.ExpectExec("INSERT INTO atlas_schema_revisions"). + WithArgs("002_next", "002_next", 1, sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + fsys := fstest.MapFS{ + "001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")}, + "002_next.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t2(id int);")}, + } + err = ensureAtlasBaselineAligned(context.Background(), db, fsys) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("error_when_checking_legacy_table", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). 
+ WithArgs("schema_migrations"). + WillReturnError(errors.New("exists failed")) + + err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{}) + require.Error(t, err) + require.Contains(t, err.Error(), "check schema_migrations") + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("error_when_counting_atlas_rows", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions"). + WillReturnError(errors.New("count failed")) + + err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{}) + require.Error(t, err) + require.Contains(t, err.Error(), "count atlas_schema_revisions") + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("error_when_creating_atlas_table", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions"). 
+ WillReturnError(errors.New("create failed")) + + err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{}) + require.Error(t, err) + require.Contains(t, err.Error(), "create atlas_schema_revisions") + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("error_when_inserting_baseline", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) + mock.ExpectExec("INSERT INTO atlas_schema_revisions"). + WithArgs("001_init", "001_init", 1, sqlmock.AnyArg()). + WillReturnError(errors.New("insert failed")) + + fsys := fstest.MapFS{ + "001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")}, + } + err = ensureAtlasBaselineAligned(context.Background(), db, fsys) + require.Error(t, err) + require.Contains(t, err.Error(), "insert atlas baseline") + require.NoError(t, mock.ExpectationsWereMet()) + }) +} + +func TestApplyMigrationsFS_ChecksumMismatchRejected(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("001_init.sql"). + WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow("mismatched-checksum")) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). 
+ WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")}, + } + err = applyMigrationsFS(context.Background(), db, fsys) + require.Error(t, err) + require.Contains(t, err.Error(), "checksum mismatch") + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyMigrationsFS_CheckMigrationQueryError(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("001_err.sql"). + WillReturnError(errors.New("query failed")) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "001_err.sql": &fstest.MapFile{Data: []byte("SELECT 1;")}, + } + err = applyMigrationsFS(context.Background(), db, fsys) + require.Error(t, err) + require.Contains(t, err.Error(), "check migration 001_err.sql") + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyMigrationsFS_SkipEmptyAndAlreadyApplied(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + + alreadySQL := "CREATE TABLE t(id int);" + checksum := migrationChecksum(alreadySQL) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("001_already.sql"). + WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow(checksum)) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). 
+ WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "000_empty.sql": &fstest.MapFile{Data: []byte(" \n\t ")}, + "001_already.sql": &fstest.MapFile{Data: []byte(alreadySQL)}, + } + err = applyMigrationsFS(context.Background(), db, fsys) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyMigrationsFS_ReadMigrationError(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "001_bad.sql": &fstest.MapFile{Mode: fs.ModeDir}, + } + err = applyMigrationsFS(context.Background(), db, fsys) + require.Error(t, err) + require.Contains(t, err.Error(), "read migration 001_bad.sql") + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestPgAdvisoryLockAndUnlock_ErrorBranches(t *testing.T) { + t.Run("context_cancelled_while_not_locked", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false)) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer cancel() + err = pgAdvisoryLock(ctx, db) + require.Error(t, err) + require.Contains(t, err.Error(), "acquire migrations lock") + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("unlock_exec_error", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). 
+ WillReturnError(errors.New("unlock failed")) + + err = pgAdvisoryUnlock(context.Background(), db) + require.Error(t, err) + require.Contains(t, err.Error(), "release migrations lock") + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("acquire_lock_after_retry", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false)) + mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true)) + + ctx, cancel := context.WithTimeout(context.Background(), migrationsLockRetryInterval*3) + defer cancel() + start := time.Now() + err = pgAdvisoryLock(ctx, db) + require.NoError(t, err) + require.GreaterOrEqual(t, time.Since(start), migrationsLockRetryInterval) + require.NoError(t, mock.ExpectationsWereMet()) + }) +} + +func migrationChecksum(content string) string { + sum := sha256.Sum256([]byte(strings.TrimSpace(content))) + return hex.EncodeToString(sum[:]) +} diff --git a/backend/internal/repository/migrations_runner_notx_test.go b/backend/internal/repository/migrations_runner_notx_test.go new file mode 100644 index 00000000..db1183cd --- /dev/null +++ b/backend/internal/repository/migrations_runner_notx_test.go @@ -0,0 +1,164 @@ +package repository + +import ( + "context" + "database/sql" + "testing" + "testing/fstest" + + sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/require" +) + +func TestValidateMigrationExecutionMode(t *testing.T) { + t.Run("事务迁移包含CONCURRENTLY会被拒绝", func(t *testing.T) { + nonTx, err := validateMigrationExecutionMode("001_add_idx.sql", "CREATE INDEX CONCURRENTLY idx_a ON t(a);") + require.False(t, nonTx) + require.Error(t, err) + }) + + t.Run("notx迁移要求CREATE使用IF NOT 
EXISTS", func(t *testing.T) { + nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "CREATE INDEX CONCURRENTLY idx_a ON t(a);") + require.False(t, nonTx) + require.Error(t, err) + }) + + t.Run("notx迁移要求DROP使用IF EXISTS", func(t *testing.T) { + nonTx, err := validateMigrationExecutionMode("001_drop_idx_notx.sql", "DROP INDEX CONCURRENTLY idx_a;") + require.False(t, nonTx) + require.Error(t, err) + }) + + t.Run("notx迁移禁止事务控制语句", func(t *testing.T) { + nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "BEGIN; CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); COMMIT;") + require.False(t, nonTx) + require.Error(t, err) + }) + + t.Run("notx迁移禁止混用非CONCURRENTLY语句", func(t *testing.T) { + nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); UPDATE t SET a = 1;") + require.False(t, nonTx) + require.Error(t, err) + }) + + t.Run("notx迁移允许幂等并发索引语句", func(t *testing.T) { + nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", ` +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); +DROP INDEX CONCURRENTLY IF EXISTS idx_b; +`) + require.True(t, nonTx) + require.NoError(t, err) + }) +} + +func TestApplyMigrationsFS_NonTransactionalMigration(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("001_add_idx_notx.sql"). + WillReturnError(sql.ErrNoRows) + mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t\\(a\\)"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)"). + WithArgs("001_add_idx_notx.sql", sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). 
+ WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "001_add_idx_notx.sql": &fstest.MapFile{ + Data: []byte("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t(a);"), + }, + } + + err = applyMigrationsFS(context.Background(), db, fsys) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyMigrationsFS_NonTransactionalMigration_MultiStatements(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("001_add_multi_idx_notx.sql"). + WillReturnError(sql.ErrNoRows) + mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t\\(a\\)"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t\\(b\\)"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)"). + WithArgs("001_add_multi_idx_notx.sql", sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). 
+ WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "001_add_multi_idx_notx.sql": &fstest.MapFile{ + Data: []byte(` +-- first +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t(a); +-- second +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b); +`), + }, + } + + err = applyMigrationsFS(context.Background(), db, fsys) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("001_add_col.sql"). + WillReturnError(sql.ErrNoRows) + mock.ExpectBegin() + mock.ExpectExec("ALTER TABLE t ADD COLUMN name TEXT"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)"). + WithArgs("001_add_col.sql", sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "001_add_col.sql": &fstest.MapFile{ + Data: []byte("ALTER TABLE t ADD COLUMN name TEXT;"), + }, + } + + err = applyMigrationsFS(context.Background(), db, fsys) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func prepareMigrationsBootstrapExpectations(mock sqlmock.Sqlmock) { + mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true)) + mock.ExpectExec("CREATE TABLE IF NOT EXISTS schema_migrations"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). 
+ WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) +} diff --git a/backend/internal/repository/migrations_schema_integration_test.go b/backend/internal/repository/migrations_schema_integration_test.go index bc37ee72..72422d18 100644 --- a/backend/internal/repository/migrations_schema_integration_test.go +++ b/backend/internal/repository/migrations_schema_integration_test.go @@ -42,12 +42,19 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { // usage_logs: billing_type used by filters/stats requireColumn(t, tx, "usage_logs", "billing_type", "smallint", 0, false) + requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false) + requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false) // settings table should exist var settingsRegclass sql.NullString require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass)) require.True(t, settingsRegclass.Valid, "expected settings table to exist") + // security_secrets table should exist + var securitySecretsRegclass sql.NullString + require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.security_secrets')").Scan(&securitySecretsRegclass)) + require.True(t, securitySecretsRegclass.Valid, "expected security_secrets table to exist") + // user_allowed_groups table should exist var uagRegclass sql.NullString require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.user_allowed_groups')").Scan(&uagRegclass)) diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go index 394d3a1a..dca0b612 100644 --- 
a/backend/internal/repository/openai_oauth_service.go +++ b/backend/internal/repository/openai_oauth_service.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "net/url" + "strings" "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" @@ -21,16 +22,23 @@ type openaiOAuthService struct { tokenURL string } -func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) { - client := createOpenAIReqClient(proxyURL) +func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { + client, err := createOpenAIReqClient(proxyURL) + if err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_CLIENT_INIT_FAILED", "create HTTP client: %v", err) + } if redirectURI == "" { redirectURI = openai.DefaultRedirectURI } + clientID = strings.TrimSpace(clientID) + if clientID == "" { + clientID = openai.ClientID + } formData := url.Values{} formData.Set("grant_type", "authorization_code") - formData.Set("client_id", openai.ClientID) + formData.Set("client_id", clientID) formData.Set("code", code) formData.Set("redirect_uri", redirectURI) formData.Set("code_verifier", codeVerifier) @@ -56,12 +64,28 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie } func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { - client := createOpenAIReqClient(proxyURL) + return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "") +} + +func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + // 调用方应始终传入正确的 client_id;为兼容旧数据,未指定时默认使用 OpenAI ClientID + clientID = strings.TrimSpace(clientID) + if clientID == "" { + clientID = openai.ClientID + } + return s.refreshTokenWithClientID(ctx, 
refreshToken, proxyURL, clientID) +} + +func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) { + client, err := createOpenAIReqClient(proxyURL) + if err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_CLIENT_INIT_FAILED", "create HTTP client: %v", err) + } formData := url.Values{} formData.Set("grant_type", "refresh_token") formData.Set("refresh_token", refreshToken) - formData.Set("client_id", openai.ClientID) + formData.Set("client_id", clientID) formData.Set("scope", openai.RefreshScopes) var tokenResp openai.TokenResponse @@ -84,7 +108,7 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro return &tokenResp, nil } -func createOpenAIReqClient(proxyURL string) *req.Client { +func createOpenAIReqClient(proxyURL string) (*req.Client, error) { return getSharedReqClient(reqClientOptions{ ProxyURL: proxyURL, Timeout: 120 * time.Second, diff --git a/backend/internal/repository/openai_oauth_service_test.go b/backend/internal/repository/openai_oauth_service_test.go index f9df08c8..44fa291b 100644 --- a/backend/internal/repository/openai_oauth_service_test.go +++ b/backend/internal/repository/openai_oauth_service_test.go @@ -81,7 +81,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() { _, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`) })) - resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "") + resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "", "") require.NoError(s.T(), err, "ExchangeCode") select { case msg := <-errCh: @@ -136,13 +136,84 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() { require.Equal(s.T(), "rt2", resp.RefreshToken) } +// TestRefreshToken_DefaultsToOpenAIClientID 验证未指定 client_id 时默认使用 OpenAI ClientID, +// 且只发送一次请求(不再盲猜多个 client_id)。 +func (s *OpenAIOAuthServiceSuite) 
TestRefreshToken_DefaultsToOpenAIClientID() { + var seenClientIDs []string + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + clientID := r.PostForm.Get("client_id") + seenClientIDs = append(seenClientIDs, clientID) + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`) + })) + + resp, err := s.svc.RefreshToken(s.ctx, "rt", "") + require.NoError(s.T(), err, "RefreshToken") + require.Equal(s.T(), "at", resp.AccessToken) + // 只发送了一次请求,使用默认的 OpenAI ClientID + require.Equal(s.T(), []string{openai.ClientID}, seenClientIDs) +} + +// TestRefreshToken_UseSoraClientID 验证显式传入 Sora ClientID 时直接使用,不回退。 +func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseSoraClientID() { + var seenClientIDs []string + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + clientID := r.PostForm.Get("client_id") + seenClientIDs = append(seenClientIDs, clientID) + if clientID == openai.SoraClientID { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`) + return + } + w.WriteHeader(http.StatusBadRequest) + })) + + resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", openai.SoraClientID) + require.NoError(s.T(), err, "RefreshTokenWithClientID") + require.Equal(s.T(), "at-sora", resp.AccessToken) + require.Equal(s.T(), []string{openai.SoraClientID}, seenClientIDs) +} + +func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() { + const customClientID = "custom-client-id" + var seenClientIDs []string + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err 
!= nil { + w.WriteHeader(http.StatusBadRequest) + return + } + clientID := r.PostForm.Get("client_id") + seenClientIDs = append(seenClientIDs, clientID) + if clientID != customClientID { + w.WriteHeader(http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"access_token":"at-custom","refresh_token":"rt-custom","token_type":"bearer","expires_in":3600}`) + })) + + resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", customClientID) + require.NoError(s.T(), err, "RefreshTokenWithClientID") + require.Equal(s.T(), "at-custom", resp.AccessToken) + require.Equal(s.T(), "rt-custom", resp.RefreshToken) + require.Equal(s.T(), []string{customClientID}, seenClientIDs) +} + func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() { s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) _, _ = io.WriteString(w, "bad") })) - _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "") + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "") require.Error(s.T(), err) require.ErrorContains(s.T(), err, "status 400") require.ErrorContains(s.T(), err, "bad") @@ -152,7 +223,7 @@ func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() { s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) s.srv.Close() - _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "") + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "") require.Error(s.T(), err) require.ErrorContains(s.T(), err, "request failed") } @@ -169,7 +240,7 @@ func (s *OpenAIOAuthServiceSuite) TestContextCancel() { done := make(chan error, 1) go func() { - _, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "") + _, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "", "") done <- err }() @@ -195,7 
+266,30 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() { _, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`) })) - _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "") + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "", "") + require.NoError(s.T(), err, "ExchangeCode") + select { + case msg := <-errCh: + require.Fail(s.T(), msg) + default: + } +} + +func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() { + wantClientID := openai.SoraClientID + errCh := make(chan string, 1) + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = r.ParseForm() + if got := r.PostForm.Get("client_id"); got != wantClientID { + errCh <- "client_id mismatch" + w.WriteHeader(http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`) + })) + + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", wantClientID) require.NoError(s.T(), err, "ExchangeCode") select { case msg := <-errCh: @@ -213,7 +307,7 @@ func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() { })) s.svc.tokenURL = s.srv.URL + "?x=1" - _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "") + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "") require.NoError(s.T(), err, "ExchangeCode") select { case <-s.received: @@ -229,7 +323,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() { _, _ = io.WriteString(w, "not-valid-json") })) - _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "") + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "") require.Error(s.T(), err, "expected error for invalid JSON response") } diff --git a/backend/internal/repository/ops_repo.go 
b/backend/internal/repository/ops_repo.go index b04154b7..989573f2 100644 --- a/backend/internal/repository/ops_repo.go +++ b/backend/internal/repository/ops_repo.go @@ -3,6 +3,7 @@ package repository import ( "context" "database/sql" + "encoding/json" "fmt" "strings" "time" @@ -55,6 +56,10 @@ INSERT INTO ops_error_logs ( upstream_error_message, upstream_error_detail, upstream_errors, + auth_latency_ms, + routing_latency_ms, + upstream_latency_ms, + response_latency_ms, time_to_first_token_ms, request_body, request_body_truncated, @@ -64,7 +69,7 @@ INSERT INTO ops_error_logs ( retry_count, created_at ) VALUES ( - $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34 + $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38 ) RETURNING id` var id int64 @@ -97,6 +102,10 @@ INSERT INTO ops_error_logs ( opsNullString(input.UpstreamErrorMessage), opsNullString(input.UpstreamErrorDetail), opsNullString(input.UpstreamErrorsJSON), + opsNullInt64(input.AuthLatencyMs), + opsNullInt64(input.RoutingLatencyMs), + opsNullInt64(input.UpstreamLatencyMs), + opsNullInt64(input.ResponseLatencyMs), opsNullInt64(input.TimeToFirstTokenMs), opsNullString(input.RequestBodyJSON), input.RequestBodyTruncated, @@ -930,6 +939,243 @@ WHERE id = $1` return err } +func (r *opsRepository) BatchInsertSystemLogs(ctx context.Context, inputs []*service.OpsInsertSystemLogInput) (int64, error) { + if r == nil || r.db == nil { + return 0, fmt.Errorf("nil ops repository") + } + if len(inputs) == 0 { + return 0, nil + } + + tx, err := r.db.BeginTx(ctx, nil) + if err != nil { + return 0, err + } + stmt, err := tx.PrepareContext(ctx, pq.CopyIn( + "ops_system_logs", + "created_at", + "level", + "component", + "message", + "request_id", + "client_request_id", + "user_id", + "account_id", + "platform", + "model", + "extra", + )) + if err != nil { + _ = 
tx.Rollback() + return 0, err + } + + var inserted int64 + for _, input := range inputs { + if input == nil { + continue + } + createdAt := input.CreatedAt + if createdAt.IsZero() { + createdAt = time.Now().UTC() + } + component := strings.TrimSpace(input.Component) + level := strings.ToLower(strings.TrimSpace(input.Level)) + message := strings.TrimSpace(input.Message) + if level == "" || message == "" { + continue + } + if component == "" { + component = "app" + } + extra := strings.TrimSpace(input.ExtraJSON) + if extra == "" { + extra = "{}" + } + if _, err := stmt.ExecContext( + ctx, + createdAt.UTC(), + level, + component, + message, + opsNullString(input.RequestID), + opsNullString(input.ClientRequestID), + opsNullInt64(input.UserID), + opsNullInt64(input.AccountID), + opsNullString(input.Platform), + opsNullString(input.Model), + extra, + ); err != nil { + _ = stmt.Close() + _ = tx.Rollback() + return inserted, err + } + inserted++ + } + + if _, err := stmt.ExecContext(ctx); err != nil { + _ = stmt.Close() + _ = tx.Rollback() + return inserted, err + } + if err := stmt.Close(); err != nil { + _ = tx.Rollback() + return inserted, err + } + if err := tx.Commit(); err != nil { + return inserted, err + } + return inserted, nil +} + +func (r *opsRepository) ListSystemLogs(ctx context.Context, filter *service.OpsSystemLogFilter) (*service.OpsSystemLogList, error) { + if r == nil || r.db == nil { + return nil, fmt.Errorf("nil ops repository") + } + if filter == nil { + filter = &service.OpsSystemLogFilter{} + } + + page := filter.Page + if page <= 0 { + page = 1 + } + pageSize := filter.PageSize + if pageSize <= 0 { + pageSize = 50 + } + if pageSize > 200 { + pageSize = 200 + } + + where, args, _ := buildOpsSystemLogsWhere(filter) + countSQL := "SELECT COUNT(*) FROM ops_system_logs l " + where + var total int + if err := r.db.QueryRowContext(ctx, countSQL, args...).Scan(&total); err != nil { + return nil, err + } + + offset := (page - 1) * pageSize + argsWithLimit 
:= append(args, pageSize, offset) + query := ` +SELECT + l.id, + l.created_at, + l.level, + COALESCE(l.component, ''), + COALESCE(l.message, ''), + COALESCE(l.request_id, ''), + COALESCE(l.client_request_id, ''), + l.user_id, + l.account_id, + COALESCE(l.platform, ''), + COALESCE(l.model, ''), + COALESCE(l.extra::text, '{}') +FROM ops_system_logs l +` + where + ` +ORDER BY l.created_at DESC, l.id DESC +LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2) + + rows, err := r.db.QueryContext(ctx, query, argsWithLimit...) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + logs := make([]*service.OpsSystemLog, 0, pageSize) + for rows.Next() { + item := &service.OpsSystemLog{} + var userID sql.NullInt64 + var accountID sql.NullInt64 + var extraRaw string + if err := rows.Scan( + &item.ID, + &item.CreatedAt, + &item.Level, + &item.Component, + &item.Message, + &item.RequestID, + &item.ClientRequestID, + &userID, + &accountID, + &item.Platform, + &item.Model, + &extraRaw, + ); err != nil { + return nil, err + } + if userID.Valid { + v := userID.Int64 + item.UserID = &v + } + if accountID.Valid { + v := accountID.Int64 + item.AccountID = &v + } + extraRaw = strings.TrimSpace(extraRaw) + if extraRaw != "" && extraRaw != "null" && extraRaw != "{}" { + extra := make(map[string]any) + if err := json.Unmarshal([]byte(extraRaw), &extra); err == nil { + item.Extra = extra + } + } + logs = append(logs, item) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return &service.OpsSystemLogList{ + Logs: logs, + Total: total, + Page: page, + PageSize: pageSize, + }, nil +} + +func (r *opsRepository) DeleteSystemLogs(ctx context.Context, filter *service.OpsSystemLogCleanupFilter) (int64, error) { + if r == nil || r.db == nil { + return 0, fmt.Errorf("nil ops repository") + } + if filter == nil { + filter = &service.OpsSystemLogCleanupFilter{} + } + + where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(filter) + if 
!hasConstraint { + return 0, fmt.Errorf("cleanup requires at least one filter condition") + } + + query := "DELETE FROM ops_system_logs l " + where + res, err := r.db.ExecContext(ctx, query, args...) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (r *opsRepository) InsertSystemLogCleanupAudit(ctx context.Context, input *service.OpsSystemLogCleanupAudit) error { + if r == nil || r.db == nil { + return fmt.Errorf("nil ops repository") + } + if input == nil { + return fmt.Errorf("nil input") + } + createdAt := input.CreatedAt + if createdAt.IsZero() { + createdAt = time.Now().UTC() + } + _, err := r.db.ExecContext(ctx, ` +INSERT INTO ops_system_log_cleanup_audits ( + created_at, + operator_id, + conditions, + deleted_rows +) VALUES ($1,$2,$3,$4) +`, createdAt.UTC(), input.OperatorID, input.Conditions, input.DeletedRows) + return err +} + func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) { clauses := make([]string, 0, 12) args := make([]any, 0, 12) @@ -948,7 +1194,7 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) { } // Keep list endpoints scoped to client errors unless explicitly filtering upstream phase. 
if phaseFilter != "upstream" { - clauses = append(clauses, "COALESCE(status_code, 0) >= 400") + clauses = append(clauses, "COALESCE(e.status_code, 0) >= 400") } if filter.StartTime != nil && !filter.StartTime.IsZero() { @@ -962,33 +1208,33 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) { } if p := strings.TrimSpace(filter.Platform); p != "" { args = append(args, p) - clauses = append(clauses, "platform = $"+itoa(len(args))) + clauses = append(clauses, "e.platform = $"+itoa(len(args))) } if filter.GroupID != nil && *filter.GroupID > 0 { args = append(args, *filter.GroupID) - clauses = append(clauses, "group_id = $"+itoa(len(args))) + clauses = append(clauses, "e.group_id = $"+itoa(len(args))) } if filter.AccountID != nil && *filter.AccountID > 0 { args = append(args, *filter.AccountID) - clauses = append(clauses, "account_id = $"+itoa(len(args))) + clauses = append(clauses, "e.account_id = $"+itoa(len(args))) } if phase := phaseFilter; phase != "" { args = append(args, phase) - clauses = append(clauses, "error_phase = $"+itoa(len(args))) + clauses = append(clauses, "e.error_phase = $"+itoa(len(args))) } if filter != nil { if owner := strings.TrimSpace(strings.ToLower(filter.Owner)); owner != "" { args = append(args, owner) - clauses = append(clauses, "LOWER(COALESCE(error_owner,'')) = $"+itoa(len(args))) + clauses = append(clauses, "LOWER(COALESCE(e.error_owner,'')) = $"+itoa(len(args))) } if source := strings.TrimSpace(strings.ToLower(filter.Source)); source != "" { args = append(args, source) - clauses = append(clauses, "LOWER(COALESCE(error_source,'')) = $"+itoa(len(args))) + clauses = append(clauses, "LOWER(COALESCE(e.error_source,'')) = $"+itoa(len(args))) } } if resolvedFilter != nil { args = append(args, *resolvedFilter) - clauses = append(clauses, "COALESCE(resolved,false) = $"+itoa(len(args))) + clauses = append(clauses, "COALESCE(e.resolved,false) = $"+itoa(len(args))) } // View filter: errors vs excluded vs all. 
@@ -1000,51 +1246,140 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) { } switch view { case "", "errors": - clauses = append(clauses, "COALESCE(is_business_limited,false) = false") + clauses = append(clauses, "COALESCE(e.is_business_limited,false) = false") case "excluded": - clauses = append(clauses, "COALESCE(is_business_limited,false) = true") + clauses = append(clauses, "COALESCE(e.is_business_limited,false) = true") case "all": // no-op default: // treat unknown as default 'errors' - clauses = append(clauses, "COALESCE(is_business_limited,false) = false") + clauses = append(clauses, "COALESCE(e.is_business_limited,false) = false") } if len(filter.StatusCodes) > 0 { args = append(args, pq.Array(filter.StatusCodes)) - clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+")") + clauses = append(clauses, "COALESCE(e.upstream_status_code, e.status_code, 0) = ANY($"+itoa(len(args))+")") } else if filter.StatusCodesOther { // "Other" means: status codes not in the common list. known := []int{400, 401, 403, 404, 409, 422, 429, 500, 502, 503, 504, 529} args = append(args, pq.Array(known)) - clauses = append(clauses, "NOT (COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+"))") + clauses = append(clauses, "NOT (COALESCE(e.upstream_status_code, e.status_code, 0) = ANY($"+itoa(len(args))+"))") } // Exact correlation keys (preferred for request↔upstream linkage). 
if rid := strings.TrimSpace(filter.RequestID); rid != "" { args = append(args, rid) - clauses = append(clauses, "COALESCE(request_id,'') = $"+itoa(len(args))) + clauses = append(clauses, "COALESCE(e.request_id,'') = $"+itoa(len(args))) } if crid := strings.TrimSpace(filter.ClientRequestID); crid != "" { args = append(args, crid) - clauses = append(clauses, "COALESCE(client_request_id,'') = $"+itoa(len(args))) + clauses = append(clauses, "COALESCE(e.client_request_id,'') = $"+itoa(len(args))) } if q := strings.TrimSpace(filter.Query); q != "" { like := "%" + q + "%" args = append(args, like) n := itoa(len(args)) - clauses = append(clauses, "(request_id ILIKE $"+n+" OR client_request_id ILIKE $"+n+" OR error_message ILIKE $"+n+")") + clauses = append(clauses, "(e.request_id ILIKE $"+n+" OR e.client_request_id ILIKE $"+n+" OR e.error_message ILIKE $"+n+")") } if userQuery := strings.TrimSpace(filter.UserQuery); userQuery != "" { like := "%" + userQuery + "%" args = append(args, like) n := itoa(len(args)) - clauses = append(clauses, "u.email ILIKE $"+n) + clauses = append(clauses, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $"+n+")") } return "WHERE " + strings.Join(clauses, " AND "), args } +func buildOpsSystemLogsWhere(filter *service.OpsSystemLogFilter) (string, []any, bool) { + clauses := make([]string, 0, 10) + args := make([]any, 0, 10) + clauses = append(clauses, "1=1") + hasConstraint := false + + if filter != nil && filter.StartTime != nil && !filter.StartTime.IsZero() { + args = append(args, filter.StartTime.UTC()) + clauses = append(clauses, "l.created_at >= $"+itoa(len(args))) + hasConstraint = true + } + if filter != nil && filter.EndTime != nil && !filter.EndTime.IsZero() { + args = append(args, filter.EndTime.UTC()) + clauses = append(clauses, "l.created_at < $"+itoa(len(args))) + hasConstraint = true + } + if filter != nil { + if v := strings.ToLower(strings.TrimSpace(filter.Level)); v != "" { + args = append(args, v) + 
clauses = append(clauses, "LOWER(COALESCE(l.level,'')) = $"+itoa(len(args))) + hasConstraint = true + } + if v := strings.TrimSpace(filter.Component); v != "" { + args = append(args, v) + clauses = append(clauses, "COALESCE(l.component,'') = $"+itoa(len(args))) + hasConstraint = true + } + if v := strings.TrimSpace(filter.RequestID); v != "" { + args = append(args, v) + clauses = append(clauses, "COALESCE(l.request_id,'') = $"+itoa(len(args))) + hasConstraint = true + } + if v := strings.TrimSpace(filter.ClientRequestID); v != "" { + args = append(args, v) + clauses = append(clauses, "COALESCE(l.client_request_id,'') = $"+itoa(len(args))) + hasConstraint = true + } + if filter.UserID != nil && *filter.UserID > 0 { + args = append(args, *filter.UserID) + clauses = append(clauses, "l.user_id = $"+itoa(len(args))) + hasConstraint = true + } + if filter.AccountID != nil && *filter.AccountID > 0 { + args = append(args, *filter.AccountID) + clauses = append(clauses, "l.account_id = $"+itoa(len(args))) + hasConstraint = true + } + if v := strings.TrimSpace(filter.Platform); v != "" { + args = append(args, v) + clauses = append(clauses, "COALESCE(l.platform,'') = $"+itoa(len(args))) + hasConstraint = true + } + if v := strings.TrimSpace(filter.Model); v != "" { + args = append(args, v) + clauses = append(clauses, "COALESCE(l.model,'') = $"+itoa(len(args))) + hasConstraint = true + } + if v := strings.TrimSpace(filter.Query); v != "" { + like := "%" + v + "%" + args = append(args, like) + n := itoa(len(args)) + clauses = append(clauses, "(l.message ILIKE $"+n+" OR COALESCE(l.request_id,'') ILIKE $"+n+" OR COALESCE(l.client_request_id,'') ILIKE $"+n+" OR COALESCE(l.extra::text,'') ILIKE $"+n+")") + hasConstraint = true + } + } + + return "WHERE " + strings.Join(clauses, " AND "), args, hasConstraint +} + +func buildOpsSystemLogsCleanupWhere(filter *service.OpsSystemLogCleanupFilter) (string, []any, bool) { + if filter == nil { + filter = &service.OpsSystemLogCleanupFilter{} 
+ } + listFilter := &service.OpsSystemLogFilter{ + StartTime: filter.StartTime, + EndTime: filter.EndTime, + Level: filter.Level, + Component: filter.Component, + RequestID: filter.RequestID, + ClientRequestID: filter.ClientRequestID, + UserID: filter.UserID, + AccountID: filter.AccountID, + Platform: filter.Platform, + Model: filter.Model, + Query: filter.Query, + } + return buildOpsSystemLogsWhere(listFilter) +} + // Helpers for nullable args func opsNullString(v any) any { switch s := v.(type) { diff --git a/backend/internal/repository/ops_repo_dashboard.go b/backend/internal/repository/ops_repo_dashboard.go index 85791a9a..b43d6706 100644 --- a/backend/internal/repository/ops_repo_dashboard.go +++ b/backend/internal/repository/ops_repo_dashboard.go @@ -12,6 +12,11 @@ import ( "github.com/Wei-Shaw/sub2api/internal/service" ) +const ( + opsRawLatencyQueryTimeout = 2 * time.Second + opsRawPeakQueryTimeout = 1500 * time.Millisecond +) + func (r *opsRepository) GetDashboardOverview(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsDashboardOverview, error) { if r == nil || r.db == nil { return nil, fmt.Errorf("nil ops repository") @@ -45,15 +50,24 @@ func (r *opsRepository) GetDashboardOverview(ctx context.Context, filter *servic func (r *opsRepository) getDashboardOverviewRaw(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsDashboardOverview, error) { start := filter.StartTime.UTC() end := filter.EndTime.UTC() + degraded := false successCount, tokenConsumed, err := r.queryUsageCounts(ctx, filter, start, end) if err != nil { return nil, err } - duration, ttft, err := r.queryUsageLatency(ctx, filter, start, end) + latencyCtx, cancelLatency := context.WithTimeout(ctx, opsRawLatencyQueryTimeout) + duration, ttft, err := r.queryUsageLatency(latencyCtx, filter, start, end) + cancelLatency() if err != nil { - return nil, err + if isQueryTimeoutErr(err) { + degraded = true + duration = service.OpsPercentiles{} + ttft = 
service.OpsPercentiles{} + } else { + return nil, err + } } errorTotal, businessLimited, errorCountSLA, upstreamExcl, upstream429, upstream529, err := r.queryErrorCounts(ctx, filter, start, end) @@ -75,20 +89,40 @@ func (r *opsRepository) getDashboardOverviewRaw(ctx context.Context, filter *ser qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end) if err != nil { - return nil, err + if isQueryTimeoutErr(err) { + degraded = true + } else { + return nil, err + } } - qpsPeak, err := r.queryPeakQPS(ctx, filter, start, end) + peakCtx, cancelPeak := context.WithTimeout(ctx, opsRawPeakQueryTimeout) + qpsPeak, tpsPeak, err := r.queryPeakRates(peakCtx, filter, start, end) + cancelPeak() if err != nil { - return nil, err - } - tpsPeak, err := r.queryPeakTPS(ctx, filter, start, end) - if err != nil { - return nil, err + if isQueryTimeoutErr(err) { + degraded = true + } else { + return nil, err + } } qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds) tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds) + if degraded { + if qpsCurrent <= 0 { + qpsCurrent = qpsAvg + } + if tpsCurrent <= 0 { + tpsCurrent = tpsAvg + } + if qpsPeak <= 0 { + qpsPeak = roundTo1DP(math.Max(qpsCurrent, qpsAvg)) + } + if tpsPeak <= 0 { + tpsPeak = roundTo1DP(math.Max(tpsCurrent, tpsAvg)) + } + } return &service.OpsDashboardOverview{ StartTime: start, @@ -230,26 +264,45 @@ func (r *opsRepository) getDashboardOverviewPreaggregated(ctx context.Context, f sla := safeDivideFloat64(float64(successCount), float64(requestCountSLA)) errorRate := safeDivideFloat64(float64(errorCountSLA), float64(requestCountSLA)) upstreamErrorRate := safeDivideFloat64(float64(upstreamExcl), float64(requestCountSLA)) + degraded := false // Keep "current" rates as raw, to preserve realtime semantics. 
qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end) if err != nil { - return nil, err + if isQueryTimeoutErr(err) { + degraded = true + } else { + return nil, err + } } - // NOTE: peak still uses raw logs (minute granularity). This is typically cheaper than percentile_cont - // and keeps semantics consistent across modes. - qpsPeak, err := r.queryPeakQPS(ctx, filter, start, end) + peakCtx, cancelPeak := context.WithTimeout(ctx, opsRawPeakQueryTimeout) + qpsPeak, tpsPeak, err := r.queryPeakRates(peakCtx, filter, start, end) + cancelPeak() if err != nil { - return nil, err - } - tpsPeak, err := r.queryPeakTPS(ctx, filter, start, end) - if err != nil { - return nil, err + if isQueryTimeoutErr(err) { + degraded = true + } else { + return nil, err + } } qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds) tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds) + if degraded { + if qpsCurrent <= 0 { + qpsCurrent = qpsAvg + } + if tpsCurrent <= 0 { + tpsCurrent = tpsAvg + } + if qpsPeak <= 0 { + qpsPeak = roundTo1DP(math.Max(qpsCurrent, qpsAvg)) + } + if tpsPeak <= 0 { + tpsPeak = roundTo1DP(math.Max(tpsCurrent, tpsAvg)) + } + } return &service.OpsDashboardOverview{ StartTime: start, @@ -577,9 +630,16 @@ func (r *opsRepository) queryRawPartial(ctx context.Context, filter *service.Ops return nil, err } - duration, ttft, err := r.queryUsageLatency(ctx, filter, start, end) + latencyCtx, cancelLatency := context.WithTimeout(ctx, opsRawLatencyQueryTimeout) + duration, ttft, err := r.queryUsageLatency(latencyCtx, filter, start, end) + cancelLatency() if err != nil { - return nil, err + if isQueryTimeoutErr(err) { + duration = service.OpsPercentiles{} + ttft = service.OpsPercentiles{} + } else { + return nil, err + } } errorTotal, businessLimited, errorCountSLA, upstreamExcl, upstream429, upstream529, err := r.queryErrorCounts(ctx, filter, start, end) @@ -735,68 +795,56 @@ FROM usage_logs ul } func (r *opsRepository) queryUsageLatency(ctx 
context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (duration service.OpsPercentiles, ttft service.OpsPercentiles, err error) { - { - join, where, args, _ := buildUsageWhere(filter, start, end, 1) - q := ` + join, where, args, _ := buildUsageWhere(filter, start, end, 1) + q := ` SELECT - percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) AS p50, - percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) AS p90, - percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) AS p95, - percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) AS p99, - AVG(duration_ms) AS avg_ms, - MAX(duration_ms) AS max_ms + percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p50, + percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p90, + percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p95, + percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p99, + AVG(duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_avg, + MAX(duration_ms) AS duration_max, + percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p50, + percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p90, + percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p95, + percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p99, + AVG(first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_avg, + MAX(first_token_ms) AS ttft_max FROM usage_logs ul ` + join + ` -` + where + ` -AND duration_ms IS NOT NULL` +` + where - var p50, p90, p95, p99 sql.NullFloat64 - var avg sql.NullFloat64 - var max sql.NullInt64 - if err := 
r.db.QueryRowContext(ctx, q, args...).Scan(&p50, &p90, &p95, &p99, &avg, &max); err != nil { - return service.OpsPercentiles{}, service.OpsPercentiles{}, err - } - duration.P50 = floatToIntPtr(p50) - duration.P90 = floatToIntPtr(p90) - duration.P95 = floatToIntPtr(p95) - duration.P99 = floatToIntPtr(p99) - duration.Avg = floatToIntPtr(avg) - if max.Valid { - v := int(max.Int64) - duration.Max = &v - } + var dP50, dP90, dP95, dP99 sql.NullFloat64 + var dAvg sql.NullFloat64 + var dMax sql.NullInt64 + var tP50, tP90, tP95, tP99 sql.NullFloat64 + var tAvg sql.NullFloat64 + var tMax sql.NullInt64 + if err := r.db.QueryRowContext(ctx, q, args...).Scan( + &dP50, &dP90, &dP95, &dP99, &dAvg, &dMax, + &tP50, &tP90, &tP95, &tP99, &tAvg, &tMax, + ); err != nil { + return service.OpsPercentiles{}, service.OpsPercentiles{}, err } - { - join, where, args, _ := buildUsageWhere(filter, start, end, 1) - q := ` -SELECT - percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) AS p50, - percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) AS p90, - percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) AS p95, - percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) AS p99, - AVG(first_token_ms) AS avg_ms, - MAX(first_token_ms) AS max_ms -FROM usage_logs ul -` + join + ` -` + where + ` -AND first_token_ms IS NOT NULL` + duration.P50 = floatToIntPtr(dP50) + duration.P90 = floatToIntPtr(dP90) + duration.P95 = floatToIntPtr(dP95) + duration.P99 = floatToIntPtr(dP99) + duration.Avg = floatToIntPtr(dAvg) + if dMax.Valid { + v := int(dMax.Int64) + duration.Max = &v + } - var p50, p90, p95, p99 sql.NullFloat64 - var avg sql.NullFloat64 - var max sql.NullInt64 - if err := r.db.QueryRowContext(ctx, q, args...).Scan(&p50, &p90, &p95, &p99, &avg, &max); err != nil { - return service.OpsPercentiles{}, service.OpsPercentiles{}, err - } - ttft.P50 = floatToIntPtr(p50) - ttft.P90 = floatToIntPtr(p90) - ttft.P95 = floatToIntPtr(p95) - ttft.P99 = floatToIntPtr(p99) - ttft.Avg 
= floatToIntPtr(avg) - if max.Valid { - v := int(max.Int64) - ttft.Max = &v - } + ttft.P50 = floatToIntPtr(tP50) + ttft.P90 = floatToIntPtr(tP90) + ttft.P95 = floatToIntPtr(tP95) + ttft.P99 = floatToIntPtr(tP99) + ttft.Avg = floatToIntPtr(tAvg) + if tMax.Valid { + v := int(tMax.Int64) + ttft.Max = &v } return duration, ttft, nil @@ -854,20 +902,23 @@ func (r *opsRepository) queryCurrentRates(ctx context.Context, filter *service.O return qpsCurrent, tpsCurrent, nil } -func (r *opsRepository) queryPeakQPS(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (float64, error) { +func (r *opsRepository) queryPeakRates(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (qpsPeak float64, tpsPeak float64, err error) { usageJoin, usageWhere, usageArgs, next := buildUsageWhere(filter, start, end, 1) errorWhere, errorArgs, _ := buildErrorWhere(filter, start, end, next) q := ` WITH usage_buckets AS ( - SELECT date_trunc('minute', ul.created_at) AS bucket, COUNT(*) AS cnt + SELECT + date_trunc('minute', ul.created_at) AS bucket, + COUNT(*) AS req_cnt, + COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_cnt FROM usage_logs ul ` + usageJoin + ` ` + usageWhere + ` GROUP BY 1 ), error_buckets AS ( - SELECT date_trunc('minute', created_at) AS bucket, COUNT(*) AS cnt + SELECT date_trunc('minute', created_at) AS bucket, COUNT(*) AS err_cnt FROM ops_error_logs ` + errorWhere + ` AND COALESCE(status_code, 0) >= 400 @@ -875,47 +926,33 @@ error_buckets AS ( ), combined AS ( SELECT COALESCE(u.bucket, e.bucket) AS bucket, - COALESCE(u.cnt, 0) + COALESCE(e.cnt, 0) AS total + COALESCE(u.req_cnt, 0) + COALESCE(e.err_cnt, 0) AS total_req, + COALESCE(u.token_cnt, 0) AS total_tokens FROM usage_buckets u FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket ) -SELECT COALESCE(MAX(total), 0) FROM combined` +SELECT + COALESCE(MAX(total_req), 0) AS max_req_per_min, + COALESCE(MAX(total_tokens), 0) 
AS max_tokens_per_min +FROM combined` args := append(usageArgs, errorArgs...) - var maxPerMinute sql.NullInt64 - if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxPerMinute); err != nil { - return 0, err + var maxReqPerMinute, maxTokensPerMinute sql.NullInt64 + if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxReqPerMinute, &maxTokensPerMinute); err != nil { + return 0, 0, err } - if !maxPerMinute.Valid || maxPerMinute.Int64 <= 0 { - return 0, nil + if maxReqPerMinute.Valid && maxReqPerMinute.Int64 > 0 { + qpsPeak = roundTo1DP(float64(maxReqPerMinute.Int64) / 60.0) } - return roundTo1DP(float64(maxPerMinute.Int64) / 60.0), nil + if maxTokensPerMinute.Valid && maxTokensPerMinute.Int64 > 0 { + tpsPeak = roundTo1DP(float64(maxTokensPerMinute.Int64) / 60.0) + } + return qpsPeak, tpsPeak, nil } -func (r *opsRepository) queryPeakTPS(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (float64, error) { - join, where, args, _ := buildUsageWhere(filter, start, end, 1) - - q := ` -SELECT COALESCE(MAX(tokens_per_min), 0) -FROM ( - SELECT - date_trunc('minute', ul.created_at) AS bucket, - COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS tokens_per_min - FROM usage_logs ul - ` + join + ` - ` + where + ` - GROUP BY 1 -) t` - - var maxPerMinute sql.NullInt64 - if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxPerMinute); err != nil { - return 0, err - } - if !maxPerMinute.Valid || maxPerMinute.Int64 <= 0 { - return 0, nil - } - return roundTo1DP(float64(maxPerMinute.Int64) / 60.0), nil +func isQueryTimeoutErr(err error) bool { + return errors.Is(err, context.DeadlineExceeded) } func buildUsageWhere(filter *service.OpsDashboardFilter, start, end time.Time, startIndex int) (join string, where string, args []any, nextIndex int) { diff --git a/backend/internal/repository/ops_repo_dashboard_timeout_test.go b/backend/internal/repository/ops_repo_dashboard_timeout_test.go new file mode 100644 
index 00000000..76332ca0 --- /dev/null +++ b/backend/internal/repository/ops_repo_dashboard_timeout_test.go @@ -0,0 +1,22 @@ +package repository + +import ( + "context" + "fmt" + "testing" +) + +func TestIsQueryTimeoutErr(t *testing.T) { + if !isQueryTimeoutErr(context.DeadlineExceeded) { + t.Fatalf("context.DeadlineExceeded should be treated as query timeout") + } + if !isQueryTimeoutErr(fmt.Errorf("wrapped: %w", context.DeadlineExceeded)) { + t.Fatalf("wrapped context.DeadlineExceeded should be treated as query timeout") + } + if isQueryTimeoutErr(context.Canceled) { + t.Fatalf("context.Canceled should not be treated as query timeout") + } + if isQueryTimeoutErr(fmt.Errorf("wrapped: %w", context.Canceled)) { + t.Fatalf("wrapped context.Canceled should not be treated as query timeout") + } +} diff --git a/backend/internal/repository/ops_repo_error_where_test.go b/backend/internal/repository/ops_repo_error_where_test.go new file mode 100644 index 00000000..9ab1a89a --- /dev/null +++ b/backend/internal/repository/ops_repo_error_where_test.go @@ -0,0 +1,48 @@ +package repository + +import ( + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func TestBuildOpsErrorLogsWhere_QueryUsesQualifiedColumns(t *testing.T) { + filter := &service.OpsErrorLogFilter{ + Query: "ACCESS_DENIED", + } + + where, args := buildOpsErrorLogsWhere(filter) + if where == "" { + t.Fatalf("where should not be empty") + } + if len(args) != 1 { + t.Fatalf("args len = %d, want 1", len(args)) + } + if !strings.Contains(where, "e.request_id ILIKE $") { + t.Fatalf("where should include qualified request_id condition: %s", where) + } + if !strings.Contains(where, "e.client_request_id ILIKE $") { + t.Fatalf("where should include qualified client_request_id condition: %s", where) + } + if !strings.Contains(where, "e.error_message ILIKE $") { + t.Fatalf("where should include qualified error_message condition: %s", where) + } +} + +func 
TestBuildOpsErrorLogsWhere_UserQueryUsesExistsSubquery(t *testing.T) { + filter := &service.OpsErrorLogFilter{ + UserQuery: "admin@", + } + + where, args := buildOpsErrorLogsWhere(filter) + if where == "" { + t.Fatalf("where should not be empty") + } + if len(args) != 1 { + t.Fatalf("args len = %d, want 1", len(args)) + } + if !strings.Contains(where, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $") { + t.Fatalf("where should include EXISTS user email condition: %s", where) + } +} diff --git a/backend/internal/repository/ops_repo_latency_histogram_buckets.go b/backend/internal/repository/ops_repo_latency_histogram_buckets.go index cd5bed37..e56903f1 100644 --- a/backend/internal/repository/ops_repo_latency_histogram_buckets.go +++ b/backend/internal/repository/ops_repo_latency_histogram_buckets.go @@ -35,12 +35,12 @@ func latencyHistogramRangeCaseExpr(column string) string { if b.upperMs <= 0 { continue } - _, _ = sb.WriteString(fmt.Sprintf("\tWHEN %s < %d THEN '%s'\n", column, b.upperMs, b.label)) + fmt.Fprintf(&sb, "\tWHEN %s < %d THEN '%s'\n", column, b.upperMs, b.label) } // Default bucket. 
last := latencyHistogramBuckets[len(latencyHistogramBuckets)-1] - _, _ = sb.WriteString(fmt.Sprintf("\tELSE '%s'\n", last.label)) + fmt.Fprintf(&sb, "\tELSE '%s'\n", last.label) _, _ = sb.WriteString("END") return sb.String() } @@ -54,11 +54,11 @@ func latencyHistogramRangeOrderCaseExpr(column string) string { if b.upperMs <= 0 { continue } - _, _ = sb.WriteString(fmt.Sprintf("\tWHEN %s < %d THEN %d\n", column, b.upperMs, order)) + fmt.Fprintf(&sb, "\tWHEN %s < %d THEN %d\n", column, b.upperMs, order) order++ } - _, _ = sb.WriteString(fmt.Sprintf("\tELSE %d\n", order)) + fmt.Fprintf(&sb, "\tELSE %d\n", order) _, _ = sb.WriteString("END") return sb.String() } diff --git a/backend/internal/repository/ops_repo_openai_token_stats.go b/backend/internal/repository/ops_repo_openai_token_stats.go new file mode 100644 index 00000000..6aea416e --- /dev/null +++ b/backend/internal/repository/ops_repo_openai_token_stats.go @@ -0,0 +1,145 @@ +package repository + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func (r *opsRepository) GetOpenAITokenStats(ctx context.Context, filter *service.OpsOpenAITokenStatsFilter) (*service.OpsOpenAITokenStatsResponse, error) { + if r == nil || r.db == nil { + return nil, fmt.Errorf("nil ops repository") + } + if filter == nil { + return nil, fmt.Errorf("nil filter") + } + if filter.StartTime.IsZero() || filter.EndTime.IsZero() { + return nil, fmt.Errorf("start_time/end_time required") + } + // 允许 start_time == end_time(结果为空),与 service 层校验口径保持一致。 + if filter.StartTime.After(filter.EndTime) { + return nil, fmt.Errorf("start_time must be <= end_time") + } + + dashboardFilter := &service.OpsDashboardFilter{ + StartTime: filter.StartTime.UTC(), + EndTime: filter.EndTime.UTC(), + Platform: strings.TrimSpace(strings.ToLower(filter.Platform)), + GroupID: filter.GroupID, + } + + join, where, baseArgs, next := buildUsageWhere(dashboardFilter, dashboardFilter.StartTime, 
dashboardFilter.EndTime, 1) + where += " AND ul.model LIKE 'gpt%'" + + baseCTE := ` +WITH stats AS ( + SELECT + ul.model AS model, + COUNT(*)::bigint AS request_count, + ROUND( + AVG( + CASE + WHEN ul.duration_ms > 0 AND ul.output_tokens > 0 + THEN ul.output_tokens * 1000.0 / ul.duration_ms + END + )::numeric, + 2 + )::float8 AS avg_tokens_per_sec, + ROUND(AVG(ul.first_token_ms)::numeric, 2)::float8 AS avg_first_token_ms, + COALESCE(SUM(ul.output_tokens), 0)::bigint AS total_output_tokens, + COALESCE(ROUND(AVG(ul.duration_ms)::numeric, 0), 0)::bigint AS avg_duration_ms, + COUNT(CASE WHEN ul.first_token_ms IS NOT NULL THEN 1 END)::bigint AS requests_with_first_token + FROM usage_logs ul + ` + join + ` + ` + where + ` + GROUP BY ul.model +) +` + + countSQL := baseCTE + `SELECT COUNT(*) FROM stats` + var total int64 + if err := r.db.QueryRowContext(ctx, countSQL, baseArgs...).Scan(&total); err != nil { + return nil, err + } + + querySQL := baseCTE + ` +SELECT + model, + request_count, + avg_tokens_per_sec, + avg_first_token_ms, + total_output_tokens, + avg_duration_ms, + requests_with_first_token +FROM stats +ORDER BY request_count DESC, model ASC` + + args := make([]any, 0, len(baseArgs)+2) + args = append(args, baseArgs...) + + if filter.IsTopNMode() { + querySQL += fmt.Sprintf("\nLIMIT $%d", next) + args = append(args, filter.TopN) + } else { + offset := (filter.Page - 1) * filter.PageSize + querySQL += fmt.Sprintf("\nLIMIT $%d OFFSET $%d", next, next+1) + args = append(args, filter.PageSize, offset) + } + + rows, err := r.db.QueryContext(ctx, querySQL, args...) 
+ if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + items := make([]*service.OpsOpenAITokenStatsItem, 0, 32) + for rows.Next() { + item := &service.OpsOpenAITokenStatsItem{} + var avgTPS sql.NullFloat64 + var avgFirstToken sql.NullFloat64 + if err := rows.Scan( + &item.Model, + &item.RequestCount, + &avgTPS, + &avgFirstToken, + &item.TotalOutputTokens, + &item.AvgDurationMs, + &item.RequestsWithFirstToken, + ); err != nil { + return nil, err + } + if avgTPS.Valid { + v := avgTPS.Float64 + item.AvgTokensPerSec = &v + } + if avgFirstToken.Valid { + v := avgFirstToken.Float64 + item.AvgFirstTokenMs = &v + } + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, err + } + + resp := &service.OpsOpenAITokenStatsResponse{ + TimeRange: strings.TrimSpace(filter.TimeRange), + StartTime: dashboardFilter.StartTime, + EndTime: dashboardFilter.EndTime, + Platform: dashboardFilter.Platform, + GroupID: dashboardFilter.GroupID, + Items: items, + Total: total, + } + if filter.IsTopNMode() { + topN := filter.TopN + resp.TopN = &topN + } else { + resp.Page = filter.Page + resp.PageSize = filter.PageSize + } + return resp, nil +} diff --git a/backend/internal/repository/ops_repo_openai_token_stats_test.go b/backend/internal/repository/ops_repo_openai_token_stats_test.go new file mode 100644 index 00000000..bb01d820 --- /dev/null +++ b/backend/internal/repository/ops_repo_openai_token_stats_test.go @@ -0,0 +1,156 @@ +package repository + +import ( + "context" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestOpsRepositoryGetOpenAITokenStats_PaginationMode(t *testing.T) { + db, mock := newSQLMock(t) + repo := &opsRepository{db: db} + + start := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + groupID := int64(9) + + filter := &service.OpsOpenAITokenStatsFilter{ + TimeRange: "1d", + 
StartTime: start, + EndTime: end, + Platform: " OpenAI ", + GroupID: &groupID, + Page: 2, + PageSize: 10, + } + + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM stats`). + WithArgs(start, end, groupID, "openai"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(3))) + + rows := sqlmock.NewRows([]string{ + "model", + "request_count", + "avg_tokens_per_sec", + "avg_first_token_ms", + "total_output_tokens", + "avg_duration_ms", + "requests_with_first_token", + }). + AddRow("gpt-4o-mini", int64(20), 21.56, 120.34, int64(3000), int64(850), int64(18)). + AddRow("gpt-4.1", int64(20), 10.2, 240.0, int64(2500), int64(900), int64(20)) + + mock.ExpectQuery(`ORDER BY request_count DESC, model ASC\s+LIMIT \$5 OFFSET \$6`). + WithArgs(start, end, groupID, "openai", 10, 10). + WillReturnRows(rows) + + resp, err := repo.GetOpenAITokenStats(context.Background(), filter) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, int64(3), resp.Total) + require.Equal(t, 2, resp.Page) + require.Equal(t, 10, resp.PageSize) + require.Nil(t, resp.TopN) + require.Equal(t, "openai", resp.Platform) + require.NotNil(t, resp.GroupID) + require.Equal(t, groupID, *resp.GroupID) + require.Len(t, resp.Items, 2) + require.Equal(t, "gpt-4o-mini", resp.Items[0].Model) + require.NotNil(t, resp.Items[0].AvgTokensPerSec) + require.InDelta(t, 21.56, *resp.Items[0].AvgTokensPerSec, 0.0001) + require.NotNil(t, resp.Items[0].AvgFirstTokenMs) + require.InDelta(t, 120.34, *resp.Items[0].AvgFirstTokenMs, 0.0001) + + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestOpsRepositoryGetOpenAITokenStats_TopNMode(t *testing.T) { + db, mock := newSQLMock(t) + repo := &opsRepository{db: db} + + start := time.Date(2026, 1, 1, 10, 0, 0, 0, time.UTC) + end := start.Add(time.Hour) + filter := &service.OpsOpenAITokenStatsFilter{ + TimeRange: "1h", + StartTime: start, + EndTime: end, + TopN: 5, + } + + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM stats`). + WithArgs(start, end). 
+ WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1))) + + rows := sqlmock.NewRows([]string{ + "model", + "request_count", + "avg_tokens_per_sec", + "avg_first_token_ms", + "total_output_tokens", + "avg_duration_ms", + "requests_with_first_token", + }). + AddRow("gpt-4o", int64(5), nil, nil, int64(0), int64(0), int64(0)) + + mock.ExpectQuery(`ORDER BY request_count DESC, model ASC\s+LIMIT \$3`). + WithArgs(start, end, 5). + WillReturnRows(rows) + + resp, err := repo.GetOpenAITokenStats(context.Background(), filter) + require.NoError(t, err) + require.NotNil(t, resp) + require.NotNil(t, resp.TopN) + require.Equal(t, 5, *resp.TopN) + require.Equal(t, 0, resp.Page) + require.Equal(t, 0, resp.PageSize) + require.Len(t, resp.Items, 1) + require.Nil(t, resp.Items[0].AvgTokensPerSec) + require.Nil(t, resp.Items[0].AvgFirstTokenMs) + + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestOpsRepositoryGetOpenAITokenStats_EmptyResult(t *testing.T) { + db, mock := newSQLMock(t) + repo := &opsRepository{db: db} + + start := time.Date(2026, 1, 2, 0, 0, 0, 0, time.UTC) + end := start.Add(30 * time.Minute) + filter := &service.OpsOpenAITokenStatsFilter{ + TimeRange: "30m", + StartTime: start, + EndTime: end, + Page: 1, + PageSize: 20, + } + + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM stats`). + WithArgs(start, end). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0))) + + mock.ExpectQuery(`ORDER BY request_count DESC, model ASC\s+LIMIT \$3 OFFSET \$4`). + WithArgs(start, end, 20, 0). 
+ WillReturnRows(sqlmock.NewRows([]string{ + "model", + "request_count", + "avg_tokens_per_sec", + "avg_first_token_ms", + "total_output_tokens", + "avg_duration_ms", + "requests_with_first_token", + })) + + resp, err := repo.GetOpenAITokenStats(context.Background(), filter) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, int64(0), resp.Total) + require.Len(t, resp.Items, 0) + require.Equal(t, 1, resp.Page) + require.Equal(t, 20, resp.PageSize) + + require.NoError(t, mock.ExpectationsWereMet()) +} diff --git a/backend/internal/repository/ops_repo_system_logs_test.go b/backend/internal/repository/ops_repo_system_logs_test.go new file mode 100644 index 00000000..c3524fe4 --- /dev/null +++ b/backend/internal/repository/ops_repo_system_logs_test.go @@ -0,0 +1,86 @@ +package repository + +import ( + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func TestBuildOpsSystemLogsWhere_WithClientRequestIDAndUserID(t *testing.T) { + start := time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC) + end := time.Date(2026, 2, 2, 0, 0, 0, 0, time.UTC) + userID := int64(12) + accountID := int64(34) + + filter := &service.OpsSystemLogFilter{ + StartTime: &start, + EndTime: &end, + Level: "warn", + Component: "http.access", + RequestID: "req-1", + ClientRequestID: "creq-1", + UserID: &userID, + AccountID: &accountID, + Platform: "openai", + Model: "gpt-5", + Query: "timeout", + } + + where, args, hasConstraint := buildOpsSystemLogsWhere(filter) + if !hasConstraint { + t.Fatalf("expected hasConstraint=true") + } + if where == "" { + t.Fatalf("where should not be empty") + } + if len(args) != 11 { + t.Fatalf("args len = %d, want 11", len(args)) + } + if !contains(where, "COALESCE(l.client_request_id,'') = $") { + t.Fatalf("where should include client_request_id condition: %s", where) + } + if !contains(where, "l.user_id = $") { + t.Fatalf("where should include user_id condition: %s", where) + } +} + +func 
TestBuildOpsSystemLogsCleanupWhere_RequireConstraint(t *testing.T) { + where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(&service.OpsSystemLogCleanupFilter{}) + if hasConstraint { + t.Fatalf("expected hasConstraint=false") + } + if where == "" { + t.Fatalf("where should not be empty") + } + if len(args) != 0 { + t.Fatalf("args len = %d, want 0", len(args)) + } +} + +func TestBuildOpsSystemLogsCleanupWhere_WithClientRequestIDAndUserID(t *testing.T) { + userID := int64(9) + filter := &service.OpsSystemLogCleanupFilter{ + ClientRequestID: "creq-9", + UserID: &userID, + } + + where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(filter) + if !hasConstraint { + t.Fatalf("expected hasConstraint=true") + } + if len(args) != 2 { + t.Fatalf("args len = %d, want 2", len(args)) + } + if !contains(where, "COALESCE(l.client_request_id,'') = $") { + t.Fatalf("where should include client_request_id condition: %s", where) + } + if !contains(where, "l.user_id = $") { + t.Fatalf("where should include user_id condition: %s", where) + } +} + +func contains(s string, sub string) bool { + return strings.Contains(s, sub) +} diff --git a/backend/internal/repository/pricing_service.go b/backend/internal/repository/pricing_service.go index 07d796b8..ee8e1749 100644 --- a/backend/internal/repository/pricing_service.go +++ b/backend/internal/repository/pricing_service.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "log/slog" "net/http" "strings" "time" @@ -16,14 +17,37 @@ type pricingRemoteClient struct { httpClient *http.Client } +// pricingRemoteClientError 代理初始化失败时的错误占位客户端 +// 所有请求直接返回初始化错误,禁止回退到直连 +type pricingRemoteClientError struct { + err error +} + +func (c *pricingRemoteClientError) FetchPricingJSON(_ context.Context, _ string) ([]byte, error) { + return nil, c.err +} + +func (c *pricingRemoteClientError) FetchHashText(_ context.Context, _ string) (string, error) { + return "", c.err +} + // NewPricingRemoteClient 创建定价数据远程客户端 // proxyURL 为空时直连,支持 
http/https/socks5/socks5h 协议 -func NewPricingRemoteClient(proxyURL string) service.PricingRemoteClient { +// 代理配置失败时行为由 allowDirectOnProxyError 控制: +// - false(默认):返回错误占位客户端,禁止回退到直连 +// - true:回退到直连(仅限管理员显式开启) +func NewPricingRemoteClient(proxyURL string, allowDirectOnProxyError bool) service.PricingRemoteClient { + // 安全说明:httpclient.GetClient 的错误链(url.Parse / proxyutil)不含明文代理凭据, + // 但仍通过 slog 仅在服务端日志记录,不会暴露给 HTTP 响应。 sharedClient, err := httpclient.GetClient(httpclient.Options{ Timeout: 30 * time.Second, ProxyURL: proxyURL, }) if err != nil { + if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError { + slog.Warn("proxy client init failed, all requests will fail", "service", "pricing", "error", err) + return &pricingRemoteClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)} + } sharedClient = &http.Client{Timeout: 30 * time.Second} } return &pricingRemoteClient{ diff --git a/backend/internal/repository/pricing_service_test.go b/backend/internal/repository/pricing_service_test.go index 6ea11211..ef2f214b 100644 --- a/backend/internal/repository/pricing_service_test.go +++ b/backend/internal/repository/pricing_service_test.go @@ -19,7 +19,7 @@ type PricingServiceSuite struct { func (s *PricingServiceSuite) SetupTest() { s.ctx = context.Background() - client, ok := NewPricingRemoteClient("").(*pricingRemoteClient) + client, ok := NewPricingRemoteClient("", false).(*pricingRemoteClient) require.True(s.T(), ok, "type assertion failed") s.client = client } @@ -140,6 +140,22 @@ func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() { require.Error(s.T(), err) } +func TestNewPricingRemoteClient_InvalidProxy_NoFallback(t *testing.T) { + client := NewPricingRemoteClient("://bad", false) + _, ok := client.(*pricingRemoteClientError) + require.True(t, ok, "should return error client when proxy is invalid and fallback disabled") + + _, 
err := client.FetchPricingJSON(context.Background(), "http://example.com") + require.Error(t, err) + require.Contains(t, err.Error(), "proxy client init failed") +} + +func TestNewPricingRemoteClient_InvalidProxy_WithFallback(t *testing.T) { + client := NewPricingRemoteClient("://bad", true) + _, ok := client.(*pricingRemoteClient) + require.True(t, ok, "should fallback to direct client when allowed") +} + func TestPricingServiceSuite(t *testing.T) { suite.Run(t, new(PricingServiceSuite)) } diff --git a/backend/internal/repository/promo_code_repo.go b/backend/internal/repository/promo_code_repo.go index 98b422e0..95ce687a 100644 --- a/backend/internal/repository/promo_code_repo.go +++ b/backend/internal/repository/promo_code_repo.go @@ -132,7 +132,7 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina q = q.Where(promocode.CodeContainsFold(search)) } - total, err := q.Count(ctx) + total, err := q.Clone().Count(ctx) if err != nil { return nil, nil, err } @@ -187,7 +187,7 @@ func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCo q := r.client.PromoCodeUsage.Query(). 
Where(promocodeusage.PromoCodeIDEQ(promoCodeID)) - total, err := q.Count(ctx) + total, err := q.Clone().Count(ctx) if err != nil { return nil, nil, err } diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go index 513e929c..b4aeab71 100644 --- a/backend/internal/repository/proxy_probe_service.go +++ b/backend/internal/repository/proxy_probe_service.go @@ -19,10 +19,14 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber { insecure := false allowPrivate := false validateResolvedIP := true + maxResponseBytes := defaultProxyProbeResponseMaxBytes if cfg != nil { insecure = cfg.Security.ProxyProbe.InsecureSkipVerify allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts validateResolvedIP = cfg.Security.URLAllowlist.Enabled + if cfg.Gateway.ProxyProbeResponseReadMaxBytes > 0 { + maxResponseBytes = cfg.Gateway.ProxyProbeResponseReadMaxBytes + } } if insecure { log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.") @@ -31,11 +35,13 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber { insecureSkipVerify: insecure, allowPrivateHosts: allowPrivate, validateResolvedIP: validateResolvedIP, + maxResponseBytes: maxResponseBytes, } } const ( - defaultProxyProbeTimeout = 30 * time.Second + defaultProxyProbeTimeout = 30 * time.Second + defaultProxyProbeResponseMaxBytes = int64(1024 * 1024) ) // probeURLs 按优先级排列的探测 URL 列表 @@ -52,6 +58,7 @@ type proxyProbeService struct { insecureSkipVerify bool allowPrivateHosts bool validateResolvedIP bool + maxResponseBytes int64 } func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) { @@ -59,7 +66,6 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s ProxyURL: proxyURL, Timeout: defaultProxyProbeTimeout, InsecureSkipVerify: s.insecureSkipVerify, - ProxyStrict: true, ValidateResolvedIP: 
s.validateResolvedIP, AllowPrivateHosts: s.allowPrivateHosts, }) @@ -98,10 +104,17 @@ func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Clien return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode) } - body, err := io.ReadAll(resp.Body) + maxResponseBytes := s.maxResponseBytes + if maxResponseBytes <= 0 { + maxResponseBytes = defaultProxyProbeResponseMaxBytes + } + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1)) if err != nil { return nil, latencyMs, fmt.Errorf("failed to read response: %w", err) } + if int64(len(body)) > maxResponseBytes { + return nil, latencyMs, fmt.Errorf("proxy probe response exceeds limit: %d", maxResponseBytes) + } switch parser { case "ip-api": diff --git a/backend/internal/repository/req_client_pool.go b/backend/internal/repository/req_client_pool.go index af71a7ee..79b24396 100644 --- a/backend/internal/repository/req_client_pool.go +++ b/backend/internal/repository/req_client_pool.go @@ -6,6 +6,8 @@ import ( "sync" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" + "github.com/imroc/req/v3" ) @@ -33,11 +35,11 @@ var sharedReqClients sync.Map // getSharedReqClient 获取共享的 req 客户端实例 // 性能优化:相同配置复用同一客户端,避免重复创建 -func getSharedReqClient(opts reqClientOptions) *req.Client { +func getSharedReqClient(opts reqClientOptions) (*req.Client, error) { key := buildReqClientKey(opts) if cached, ok := sharedReqClients.Load(key); ok { if c, ok := cached.(*req.Client); ok { - return c + return c, nil } } @@ -48,15 +50,19 @@ func getSharedReqClient(opts reqClientOptions) *req.Client { if opts.Impersonate { client = client.ImpersonateChrome() } - if strings.TrimSpace(opts.ProxyURL) != "" { - client.SetProxyURL(strings.TrimSpace(opts.ProxyURL)) + trimmed, _, err := proxyurl.Parse(opts.ProxyURL) + if err != nil { + return nil, err + } + if trimmed != "" { + client.SetProxyURL(trimmed) } actual, _ := sharedReqClients.LoadOrStore(key, client) if c, ok := actual.(*req.Client); 
ok { - return c + return c, nil } - return client + return client, nil } func buildReqClientKey(opts reqClientOptions) string { diff --git a/backend/internal/repository/req_client_pool_test.go b/backend/internal/repository/req_client_pool_test.go index 904ed4d6..9067d012 100644 --- a/backend/internal/repository/req_client_pool_test.go +++ b/backend/internal/repository/req_client_pool_test.go @@ -26,11 +26,13 @@ func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) { ProxyURL: "http://proxy.local:8080", Timeout: time.Second, } - clientDefault := getSharedReqClient(base) + clientDefault, err := getSharedReqClient(base) + require.NoError(t, err) force := base force.ForceHTTP2 = true - clientForce := getSharedReqClient(force) + clientForce, err := getSharedReqClient(force) + require.NoError(t, err) require.NotSame(t, clientDefault, clientForce) require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force)) @@ -42,8 +44,10 @@ func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) { ProxyURL: "http://proxy.local:8080", Timeout: 2 * time.Second, } - first := getSharedReqClient(opts) - second := getSharedReqClient(opts) + first, err := getSharedReqClient(opts) + require.NoError(t, err) + second, err := getSharedReqClient(opts) + require.NoError(t, err) require.Same(t, first, second) } @@ -56,7 +60,8 @@ func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) { key := buildReqClientKey(opts) sharedReqClients.Store(key, "invalid") - client := getSharedReqClient(opts) + client, err := getSharedReqClient(opts) + require.NoError(t, err) require.NotNil(t, client) loaded, ok := sharedReqClients.Load(key) @@ -71,20 +76,45 @@ func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) { Timeout: 4 * time.Second, Impersonate: true, } - client := getSharedReqClient(opts) + client, err := getSharedReqClient(opts) + require.NoError(t, err) require.NotNil(t, client) require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts)) } 
+func TestGetSharedReqClient_InvalidProxyURL(t *testing.T) { + sharedReqClients = sync.Map{} + opts := reqClientOptions{ + ProxyURL: "://missing-scheme", + Timeout: time.Second, + } + _, err := getSharedReqClient(opts) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid proxy URL") +} + +func TestGetSharedReqClient_ProxyURLMissingHost(t *testing.T) { + sharedReqClients = sync.Map{} + opts := reqClientOptions{ + ProxyURL: "http://", + Timeout: time.Second, + } + _, err := getSharedReqClient(opts) + require.Error(t, err) + require.Contains(t, err.Error(), "proxy URL missing host") +} + func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) { sharedReqClients = sync.Map{} - client := createOpenAIReqClient("http://proxy.local:8080") + client, err := createOpenAIReqClient("http://proxy.local:8080") + require.NoError(t, err) require.Equal(t, 120*time.Second, client.GetClient().Timeout) } func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) { sharedReqClients = sync.Map{} - client := createGeminiReqClient("http://proxy.local:8080") + client, err := createGeminiReqClient("http://proxy.local:8080") + require.NoError(t, err) require.Equal(t, "", forceHTTPVersion(t, client)) } diff --git a/backend/internal/repository/rpm_cache.go b/backend/internal/repository/rpm_cache.go new file mode 100644 index 00000000..4d73ec4b --- /dev/null +++ b/backend/internal/repository/rpm_cache.go @@ -0,0 +1,141 @@ +package repository + +import ( + "context" + "errors" + "fmt" + "strconv" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +// RPM 计数器缓存常量定义 +// +// 设计说明: +// 使用 Redis 简单计数器跟踪每个账号每分钟的请求数: +// - Key: rpm:{accountID}:{minuteTimestamp} +// - Value: 当前分钟内的请求计数 +// - TTL: 120 秒(覆盖当前分钟 + 一定冗余) +// +// 使用 TxPipeline(MULTI/EXEC)执行 INCR + EXPIRE,保证原子性且兼容 Redis Cluster。 +// 通过 rdb.Time() 获取服务端时间,避免多实例时钟不同步。 +// +// 设计决策: +// - TxPipeline vs Pipeline:Pipeline 仅合并发送但不保证原子,TxPipeline 使用 MULTI/EXEC 事务保证原子执行。 +// - 
rdb.Time() 单独调用:Pipeline/TxPipeline 中无法引用前一命令的结果,因此 TIME 必须单独调用(2 RTT)。 +// Lua 脚本可以做到 1 RTT,但在 Redis Cluster 中动态拼接 key 存在 CROSSSLOT 风险,选择安全性优先。 +const ( + // RPM 计数器键前缀 + // 格式: rpm:{accountID}:{minuteTimestamp} + rpmKeyPrefix = "rpm:" + + // RPM 计数器 TTL(120 秒,覆盖当前分钟窗口 + 冗余) + rpmKeyTTL = 120 * time.Second +) + +// RPMCacheImpl RPM 计数器缓存 Redis 实现 +type RPMCacheImpl struct { + rdb *redis.Client +} + +// NewRPMCache 创建 RPM 计数器缓存 +func NewRPMCache(rdb *redis.Client) service.RPMCache { + return &RPMCacheImpl{rdb: rdb} +} + +// currentMinuteKey 获取当前分钟的完整 Redis key +// 使用 rdb.Time() 获取 Redis 服务端时间,避免多实例时钟偏差 +func (c *RPMCacheImpl) currentMinuteKey(ctx context.Context, accountID int64) (string, error) { + serverTime, err := c.rdb.Time(ctx).Result() + if err != nil { + return "", fmt.Errorf("redis TIME: %w", err) + } + minuteTS := serverTime.Unix() / 60 + return fmt.Sprintf("%s%d:%d", rpmKeyPrefix, accountID, minuteTS), nil +} + +// currentMinuteSuffix 获取当前分钟时间戳后缀(供批量操作使用) +// 使用 rdb.Time() 获取 Redis 服务端时间 +func (c *RPMCacheImpl) currentMinuteSuffix(ctx context.Context) (string, error) { + serverTime, err := c.rdb.Time(ctx).Result() + if err != nil { + return "", fmt.Errorf("redis TIME: %w", err) + } + minuteTS := serverTime.Unix() / 60 + return strconv.FormatInt(minuteTS, 10), nil +} + +// IncrementRPM 原子递增并返回当前分钟的计数 +// 使用 TxPipeline (MULTI/EXEC) 执行 INCR + EXPIRE,保证原子性且兼容 Redis Cluster +func (c *RPMCacheImpl) IncrementRPM(ctx context.Context, accountID int64) (int, error) { + key, err := c.currentMinuteKey(ctx, accountID) + if err != nil { + return 0, fmt.Errorf("rpm increment: %w", err) + } + + // 使用 TxPipeline (MULTI/EXEC) 保证 INCR + EXPIRE 原子执行 + // EXPIRE 幂等,每次都设置不影响正确性 + pipe := c.rdb.TxPipeline() + incrCmd := pipe.Incr(ctx, key) + pipe.Expire(ctx, key, rpmKeyTTL) + + if _, err := pipe.Exec(ctx); err != nil { + return 0, fmt.Errorf("rpm increment: %w", err) + } + + return int(incrCmd.Val()), nil +} + +// GetRPM 获取当前分钟的 RPM 计数 +func (c *RPMCacheImpl) GetRPM(ctx 
context.Context, accountID int64) (int, error) { + key, err := c.currentMinuteKey(ctx, accountID) + if err != nil { + return 0, fmt.Errorf("rpm get: %w", err) + } + + val, err := c.rdb.Get(ctx, key).Int() + if errors.Is(err, redis.Nil) { + return 0, nil // 当前分钟无记录 + } + if err != nil { + return 0, fmt.Errorf("rpm get: %w", err) + } + return val, nil +} + +// GetRPMBatch 批量获取多个账号的 RPM 计数(使用 Pipeline) +func (c *RPMCacheImpl) GetRPMBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { + if len(accountIDs) == 0 { + return map[int64]int{}, nil + } + + // 获取当前分钟后缀 + minuteSuffix, err := c.currentMinuteSuffix(ctx) + if err != nil { + return nil, fmt.Errorf("rpm batch get: %w", err) + } + + // 使用 Pipeline 批量 GET + pipe := c.rdb.Pipeline() + cmds := make(map[int64]*redis.StringCmd, len(accountIDs)) + for _, id := range accountIDs { + key := fmt.Sprintf("%s%d:%s", rpmKeyPrefix, id, minuteSuffix) + cmds[id] = pipe.Get(ctx, key) + } + + if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) { + return nil, fmt.Errorf("rpm batch get: %w", err) + } + + result := make(map[int64]int, len(accountIDs)) + for id, cmd := range cmds { + if val, err := cmd.Int(); err == nil { + result[id] = val + } else { + result[id] = 0 + } + } + return result, nil +} diff --git a/backend/internal/repository/security_secret_bootstrap.go b/backend/internal/repository/security_secret_bootstrap.go new file mode 100644 index 00000000..e773c238 --- /dev/null +++ b/backend/internal/repository/security_secret_bootstrap.go @@ -0,0 +1,177 @@ +package repository + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/hex" + "errors" + "fmt" + "log" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" + "github.com/Wei-Shaw/sub2api/internal/config" +) + +const ( + securitySecretKeyJWT = "jwt_secret" + securitySecretReadRetryMax = 5 + securitySecretReadRetryWait = 10 * time.Millisecond +) + +var readRandomBytes 
= rand.Read + +func ensureBootstrapSecrets(ctx context.Context, client *ent.Client, cfg *config.Config) error { + if client == nil { + return fmt.Errorf("nil ent client") + } + if cfg == nil { + return fmt.Errorf("nil config") + } + + cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret) + if cfg.JWT.Secret != "" { + storedSecret, err := createSecuritySecretIfAbsent(ctx, client, securitySecretKeyJWT, cfg.JWT.Secret) + if err != nil { + return fmt.Errorf("persist jwt secret: %w", err) + } + if storedSecret != cfg.JWT.Secret { + log.Println("Warning: configured JWT secret mismatches persisted value; using persisted secret for cross-instance consistency.") + } + cfg.JWT.Secret = storedSecret + return nil + } + + secret, created, err := getOrCreateGeneratedSecuritySecret(ctx, client, securitySecretKeyJWT, 32) + if err != nil { + return fmt.Errorf("ensure jwt secret: %w", err) + } + cfg.JWT.Secret = secret + + if created { + log.Println("Warning: JWT secret auto-generated and persisted to database. Consider rotating to a managed secret for production.") + } + return nil +} + +func getOrCreateGeneratedSecuritySecret(ctx context.Context, client *ent.Client, key string, byteLength int) (string, bool, error) { + existing, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx) + if err == nil { + value := strings.TrimSpace(existing.Value) + if len([]byte(value)) < 32 { + return "", false, fmt.Errorf("stored secret %q must be at least 32 bytes", key) + } + return value, false, nil + } + if !ent.IsNotFound(err) { + return "", false, err + } + + generated, err := generateHexSecret(byteLength) + if err != nil { + return "", false, err + } + + if err := client.SecuritySecret.Create(). + SetKey(key). + SetValue(generated). + OnConflictColumns(securitysecret.FieldKey). + DoNothing(). 
+ Exec(ctx); err != nil { + if !isSQLNoRowsError(err) { + return "", false, err + } + } + + stored, err := querySecuritySecretWithRetry(ctx, client, key) + if err != nil { + return "", false, err + } + value := strings.TrimSpace(stored.Value) + if len([]byte(value)) < 32 { + return "", false, fmt.Errorf("stored secret %q must be at least 32 bytes", key) + } + return value, value == generated, nil +} + +func createSecuritySecretIfAbsent(ctx context.Context, client *ent.Client, key, value string) (string, error) { + value = strings.TrimSpace(value) + if len([]byte(value)) < 32 { + return "", fmt.Errorf("secret %q must be at least 32 bytes", key) + } + + if err := client.SecuritySecret.Create(). + SetKey(key). + SetValue(value). + OnConflictColumns(securitysecret.FieldKey). + DoNothing(). + Exec(ctx); err != nil { + if !isSQLNoRowsError(err) { + return "", err + } + } + + stored, err := querySecuritySecretWithRetry(ctx, client, key) + if err != nil { + return "", err + } + storedValue := strings.TrimSpace(stored.Value) + if len([]byte(storedValue)) < 32 { + return "", fmt.Errorf("stored secret %q must be at least 32 bytes", key) + } + return storedValue, nil +} + +func querySecuritySecretWithRetry(ctx context.Context, client *ent.Client, key string) (*ent.SecuritySecret, error) { + var lastErr error + for attempt := 0; attempt <= securitySecretReadRetryMax; attempt++ { + stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx) + if err == nil { + return stored, nil + } + if !isSecretNotFoundError(err) { + return nil, err + } + lastErr = err + if attempt == securitySecretReadRetryMax { + break + } + + timer := time.NewTimer(securitySecretReadRetryWait) + select { + case <-ctx.Done(): + timer.Stop() + return nil, ctx.Err() + case <-timer.C: + } + } + return nil, lastErr +} + +func isSecretNotFoundError(err error) bool { + if err == nil { + return false + } + return ent.IsNotFound(err) || isSQLNoRowsError(err) +} + +func 
isSQLNoRowsError(err error) bool { + if err == nil { + return false + } + return errors.Is(err, sql.ErrNoRows) || strings.Contains(err.Error(), "no rows in result set") +} + +func generateHexSecret(byteLength int) (string, error) { + if byteLength <= 0 { + byteLength = 32 + } + buf := make([]byte, byteLength) + if _, err := readRandomBytes(buf); err != nil { + return "", fmt.Errorf("generate random secret: %w", err) + } + return hex.EncodeToString(buf), nil +} diff --git a/backend/internal/repository/security_secret_bootstrap_test.go b/backend/internal/repository/security_secret_bootstrap_test.go new file mode 100644 index 00000000..288edf33 --- /dev/null +++ b/backend/internal/repository/security_secret_bootstrap_test.go @@ -0,0 +1,337 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/hex" + "errors" + "fmt" + "strings" + "sync" + "testing" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func newSecuritySecretTestClient(t *testing.T) *dbent.Client { + t.Helper() + name := strings.ReplaceAll(t.Name(), "/", "_") + dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared&_fk=1", name) + + db, err := sql.Open("sqlite", dsn) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + return client +} + +func TestEnsureBootstrapSecretsNilInputs(t *testing.T) { + err := ensureBootstrapSecrets(context.Background(), nil, &config.Config{}) + require.Error(t, err) + require.Contains(t, err.Error(), "nil ent client") + + client := 
newSecuritySecretTestClient(t) + err = ensureBootstrapSecrets(context.Background(), client, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "nil config") +} + +func TestEnsureBootstrapSecretsGenerateAndPersistJWTSecret(t *testing.T) { + client := newSecuritySecretTestClient(t) + cfg := &config.Config{} + + err := ensureBootstrapSecrets(context.Background(), client, cfg) + require.NoError(t, err) + require.NotEmpty(t, cfg.JWT.Secret) + require.GreaterOrEqual(t, len([]byte(cfg.JWT.Secret)), 32) + + stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background()) + require.NoError(t, err) + require.Equal(t, cfg.JWT.Secret, stored.Value) +} + +func TestEnsureBootstrapSecretsLoadExistingJWTSecret(t *testing.T) { + client := newSecuritySecretTestClient(t) + _, err := client.SecuritySecret.Create().SetKey(securitySecretKeyJWT).SetValue("existing-jwt-secret-32bytes-long!!!!").Save(context.Background()) + require.NoError(t, err) + + cfg := &config.Config{} + err = ensureBootstrapSecrets(context.Background(), client, cfg) + require.NoError(t, err) + require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", cfg.JWT.Secret) +} + +func TestEnsureBootstrapSecretsRejectInvalidStoredSecret(t *testing.T) { + client := newSecuritySecretTestClient(t) + _, err := client.SecuritySecret.Create().SetKey(securitySecretKeyJWT).SetValue("too-short").Save(context.Background()) + require.NoError(t, err) + + cfg := &config.Config{} + err = ensureBootstrapSecrets(context.Background(), client, cfg) + require.Error(t, err) + require.Contains(t, err.Error(), "at least 32 bytes") +} + +func TestEnsureBootstrapSecretsPersistConfiguredJWTSecret(t *testing.T) { + client := newSecuritySecretTestClient(t) + cfg := &config.Config{ + JWT: config.JWTConfig{Secret: "configured-jwt-secret-32bytes-long!!"}, + } + + err := ensureBootstrapSecrets(context.Background(), client, cfg) + require.NoError(t, err) + + stored, err := 
client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background()) + require.NoError(t, err) + require.Equal(t, "configured-jwt-secret-32bytes-long!!", stored.Value) +} + +func TestEnsureBootstrapSecretsConfiguredSecretTooShort(t *testing.T) { + client := newSecuritySecretTestClient(t) + cfg := &config.Config{JWT: config.JWTConfig{Secret: "short"}} + + err := ensureBootstrapSecrets(context.Background(), client, cfg) + require.Error(t, err) + require.Contains(t, err.Error(), "at least 32 bytes") +} + +func TestEnsureBootstrapSecretsConfiguredSecretDuplicateIgnored(t *testing.T) { + client := newSecuritySecretTestClient(t) + _, err := client.SecuritySecret.Create(). + SetKey(securitySecretKeyJWT). + SetValue("existing-jwt-secret-32bytes-long!!!!"). + Save(context.Background()) + require.NoError(t, err) + + cfg := &config.Config{JWT: config.JWTConfig{Secret: "another-configured-jwt-secret-32!!!!"}} + err = ensureBootstrapSecrets(context.Background(), client, cfg) + require.NoError(t, err) + + stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background()) + require.NoError(t, err) + require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", stored.Value) + require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", cfg.JWT.Secret) +} + +func TestGetOrCreateGeneratedSecuritySecretTrimmedExistingValue(t *testing.T) { + client := newSecuritySecretTestClient(t) + _, err := client.SecuritySecret.Create(). + SetKey("trimmed_key"). + SetValue(" existing-trimmed-secret-32bytes-long!! "). 
+ Save(context.Background()) + require.NoError(t, err) + + value, created, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, "trimmed_key", 32) + require.NoError(t, err) + require.False(t, created) + require.Equal(t, "existing-trimmed-secret-32bytes-long!!", value) +} + +func TestGetOrCreateGeneratedSecuritySecretQueryError(t *testing.T) { + client := newSecuritySecretTestClient(t) + require.NoError(t, client.Close()) + + _, _, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, "closed_client_key", 32) + require.Error(t, err) +} + +func TestGetOrCreateGeneratedSecuritySecretCreateValidationError(t *testing.T) { + client := newSecuritySecretTestClient(t) + tooLongKey := strings.Repeat("k", 101) + + _, _, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, tooLongKey, 32) + require.Error(t, err) +} + +func TestGetOrCreateGeneratedSecuritySecretConcurrentCreation(t *testing.T) { + client := newSecuritySecretTestClient(t) + const goroutines = 8 + key := "concurrent_bootstrap_key" + + values := make([]string, goroutines) + createdFlags := make([]bool, goroutines) + errs := make([]error, goroutines) + + var wg sync.WaitGroup + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + values[idx], createdFlags[idx], errs[idx] = getOrCreateGeneratedSecuritySecret(context.Background(), client, key, 32) + }(i) + } + wg.Wait() + + for i := range errs { + require.NoError(t, errs[i]) + require.NotEmpty(t, values[i]) + } + for i := 1; i < len(values); i++ { + require.Equal(t, values[0], values[i]) + } + + createdCount := 0 + for _, created := range createdFlags { + if created { + createdCount++ + } + } + require.GreaterOrEqual(t, createdCount, 1) + require.LessOrEqual(t, createdCount, 1) + + count, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Count(context.Background()) + require.NoError(t, err) + require.Equal(t, 1, count) +} + +func 
TestGetOrCreateGeneratedSecuritySecretGenerateError(t *testing.T) { + client := newSecuritySecretTestClient(t) + originalRead := readRandomBytes + readRandomBytes = func([]byte) (int, error) { + return 0, errors.New("boom") + } + t.Cleanup(func() { + readRandomBytes = originalRead + }) + + _, _, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, "gen_error_key", 32) + require.Error(t, err) + require.Contains(t, err.Error(), "boom") +} + +func TestCreateSecuritySecretIfAbsent(t *testing.T) { + client := newSecuritySecretTestClient(t) + + _, err := createSecuritySecretIfAbsent(context.Background(), client, "abc", "short") + require.Error(t, err) + require.Contains(t, err.Error(), "at least 32 bytes") + + stored, err := createSecuritySecretIfAbsent(context.Background(), client, "abc", "valid-jwt-secret-value-32bytes-long") + require.NoError(t, err) + require.Equal(t, "valid-jwt-secret-value-32bytes-long", stored) + + stored, err = createSecuritySecretIfAbsent(context.Background(), client, "abc", "another-valid-secret-value-32bytes") + require.NoError(t, err) + require.Equal(t, "valid-jwt-secret-value-32bytes-long", stored) + + count, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ("abc")).Count(context.Background()) + require.NoError(t, err) + require.Equal(t, 1, count) +} + +func TestCreateSecuritySecretIfAbsentValidationError(t *testing.T) { + client := newSecuritySecretTestClient(t) + _, err := createSecuritySecretIfAbsent( + context.Background(), + client, + strings.Repeat("k", 101), + "valid-jwt-secret-value-32bytes-long", + ) + require.Error(t, err) +} + +func TestCreateSecuritySecretIfAbsentExecError(t *testing.T) { + client := newSecuritySecretTestClient(t) + require.NoError(t, client.Close()) + + _, err := createSecuritySecretIfAbsent(context.Background(), client, "closed-client-key", "valid-jwt-secret-value-32bytes-long") + require.Error(t, err) +} + +func TestQuerySecuritySecretWithRetrySuccess(t *testing.T) { + client 
:= newSecuritySecretTestClient(t) + created, err := client.SecuritySecret.Create(). + SetKey("retry_success_key"). + SetValue("retry-success-jwt-secret-value-32!!"). + Save(context.Background()) + require.NoError(t, err) + + got, err := querySecuritySecretWithRetry(context.Background(), client, "retry_success_key") + require.NoError(t, err) + require.Equal(t, created.ID, got.ID) + require.Equal(t, "retry-success-jwt-secret-value-32!!", got.Value) +} + +func TestQuerySecuritySecretWithRetryExhausted(t *testing.T) { + client := newSecuritySecretTestClient(t) + + _, err := querySecuritySecretWithRetry(context.Background(), client, "retry_missing_key") + require.Error(t, err) + require.True(t, isSecretNotFoundError(err)) +} + +func TestQuerySecuritySecretWithRetryContextCanceled(t *testing.T) { + client := newSecuritySecretTestClient(t) + ctx, cancel := context.WithTimeout(context.Background(), securitySecretReadRetryWait/2) + defer cancel() + + _, err := querySecuritySecretWithRetry(ctx, client, "retry_ctx_cancel_key") + require.Error(t, err) + require.ErrorIs(t, err, context.DeadlineExceeded) +} + +func TestQuerySecuritySecretWithRetryNonNotFoundError(t *testing.T) { + client := newSecuritySecretTestClient(t) + require.NoError(t, client.Close()) + + _, err := querySecuritySecretWithRetry(context.Background(), client, "retry_closed_client_key") + require.Error(t, err) + require.False(t, isSecretNotFoundError(err)) +} + +func TestSecretNotFoundHelpers(t *testing.T) { + require.False(t, isSecretNotFoundError(nil)) + require.False(t, isSQLNoRowsError(nil)) + + require.True(t, isSQLNoRowsError(sql.ErrNoRows)) + require.True(t, isSQLNoRowsError(fmt.Errorf("wrapped: %w", sql.ErrNoRows))) + require.True(t, isSQLNoRowsError(errors.New("sql: no rows in result set"))) + + require.True(t, isSecretNotFoundError(sql.ErrNoRows)) + require.True(t, isSecretNotFoundError(errors.New("sql: no rows in result set"))) + require.False(t, isSecretNotFoundError(errors.New("some other 
error"))) +} + +func TestGenerateHexSecretReadError(t *testing.T) { + originalRead := readRandomBytes + readRandomBytes = func([]byte) (int, error) { + return 0, errors.New("read random failed") + } + t.Cleanup(func() { + readRandomBytes = originalRead + }) + + _, err := generateHexSecret(32) + require.Error(t, err) + require.Contains(t, err.Error(), "read random failed") +} + +func TestGenerateHexSecretLengths(t *testing.T) { + v1, err := generateHexSecret(0) + require.NoError(t, err) + require.Len(t, v1, 64) + _, err = hex.DecodeString(v1) + require.NoError(t, err) + + v2, err := generateHexSecret(16) + require.NoError(t, err) + require.Len(t, v2, 32) + _, err = hex.DecodeString(v2) + require.NoError(t, err) + + require.NotEqual(t, v1, v2) +} diff --git a/backend/internal/repository/setting_repo_integration_test.go b/backend/internal/repository/setting_repo_integration_test.go index 147313d6..f37b2de1 100644 --- a/backend/internal/repository/setting_repo_integration_test.go +++ b/backend/internal/repository/setting_repo_integration_test.go @@ -122,7 +122,7 @@ func (s *SettingRepoSuite) TestSet_EmptyValue() { func (s *SettingRepoSuite) TestSetMultiple_WithEmptyValues() { // 模拟保存站点设置,部分字段有值,部分字段为空 settings := map[string]string{ - "site_name": "AICodex2API", + "site_name": "Sub2api", "site_subtitle": "Subscription to API", "site_logo": "", // 用户未上传Logo "api_base_url": "", // 用户未设置API地址 @@ -136,7 +136,7 @@ func (s *SettingRepoSuite) TestSetMultiple_WithEmptyValues() { result, err := s.repo.GetMultiple(s.ctx, []string{"site_name", "site_subtitle", "site_logo", "api_base_url", "contact_info", "doc_url"}) s.Require().NoError(err, "GetMultiple after SetMultiple with empty values") - s.Require().Equal("AICodex2API", result["site_name"]) + s.Require().Equal("Sub2api", result["site_name"]) s.Require().Equal("Subscription to API", result["site_subtitle"]) s.Require().Equal("", result["site_logo"], "empty site_logo should be preserved") s.Require().Equal("", 
result["api_base_url"], "empty api_base_url should be preserved") diff --git a/backend/internal/repository/soft_delete_ent_integration_test.go b/backend/internal/repository/soft_delete_ent_integration_test.go index ef63fbee..8c2b23da 100644 --- a/backend/internal/repository/soft_delete_ent_integration_test.go +++ b/backend/internal/repository/soft_delete_ent_integration_test.go @@ -41,7 +41,7 @@ func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) { u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com") - repo := NewAPIKeyRepository(client) + repo := NewAPIKeyRepository(client, integrationDB) key := &service.APIKey{ UserID: u.ID, Key: uniqueSoftDeleteValue(t, "sk-soft-delete"), @@ -73,7 +73,7 @@ func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) { u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com") - repo := NewAPIKeyRepository(client) + repo := NewAPIKeyRepository(client, integrationDB) key := &service.APIKey{ UserID: u.ID, Key: uniqueSoftDeleteValue(t, "sk-soft-delete2"), @@ -93,7 +93,7 @@ func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) { u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com") - repo := NewAPIKeyRepository(client) + repo := NewAPIKeyRepository(client, integrationDB) key := &service.APIKey{ UserID: u.ID, Key: uniqueSoftDeleteValue(t, "sk-soft-delete3"), diff --git a/backend/internal/repository/sora_account_repo.go b/backend/internal/repository/sora_account_repo.go new file mode 100644 index 00000000..ad2ae638 --- /dev/null +++ b/backend/internal/repository/sora_account_repo.go @@ -0,0 +1,98 @@ +package repository + +import ( + "context" + "database/sql" + "errors" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// soraAccountRepository 实现 service.SoraAccountRepository 接口。 +// 使用原生 SQL 操作 sora_accounts 表,因为该表不在 Ent ORM 管理范围内。 +// +// 设计说明: +// - sora_accounts 表是独立迁移创建的,不通过 Ent Schema 管理 +// - 使用 
ON CONFLICT (account_id) DO UPDATE 实现 Upsert 语义 +// - 与 accounts 主表通过外键关联,ON DELETE CASCADE 确保级联删除 +type soraAccountRepository struct { + sql *sql.DB +} + +// NewSoraAccountRepository 创建 Sora 账号扩展表仓储实例 +func NewSoraAccountRepository(sqlDB *sql.DB) service.SoraAccountRepository { + return &soraAccountRepository{sql: sqlDB} +} + +// Upsert 创建或更新 Sora 账号扩展信息 +// 使用 PostgreSQL ON CONFLICT ... DO UPDATE 实现原子性 upsert +func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error { + accessToken, accessOK := updates["access_token"].(string) + refreshToken, refreshOK := updates["refresh_token"].(string) + sessionToken, sessionOK := updates["session_token"].(string) + + if !accessOK || accessToken == "" || !refreshOK || refreshToken == "" { + if !sessionOK { + return errors.New("缺少 access_token/refresh_token,且未提供可更新字段") + } + result, err := r.sql.ExecContext(ctx, ` + UPDATE sora_accounts + SET session_token = CASE WHEN $2 = '' THEN session_token ELSE $2 END, + updated_at = NOW() + WHERE account_id = $1 + `, accountID, sessionToken) + if err != nil { + return err + } + rows, err := result.RowsAffected() + if err != nil { + return err + } + if rows == 0 { + return errors.New("sora_accounts 记录不存在,无法仅更新 session_token") + } + return nil + } + + _, err := r.sql.ExecContext(ctx, ` + INSERT INTO sora_accounts (account_id, access_token, refresh_token, session_token, created_at, updated_at) + VALUES ($1, $2, $3, $4, NOW(), NOW()) + ON CONFLICT (account_id) DO UPDATE SET + access_token = EXCLUDED.access_token, + refresh_token = EXCLUDED.refresh_token, + session_token = CASE WHEN EXCLUDED.session_token = '' THEN sora_accounts.session_token ELSE EXCLUDED.session_token END, + updated_at = NOW() + `, accountID, accessToken, refreshToken, sessionToken) + return err +} + +// GetByAccountID 根据账号 ID 获取 Sora 扩展信息 +func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) { + rows, err := 
r.sql.QueryContext(ctx, ` + SELECT account_id, access_token, refresh_token, COALESCE(session_token, '') + FROM sora_accounts + WHERE account_id = $1 + `, accountID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + return nil, nil // 记录不存在 + } + + var sa service.SoraAccount + if err := rows.Scan(&sa.AccountID, &sa.AccessToken, &sa.RefreshToken, &sa.SessionToken); err != nil { + return nil, err + } + return &sa, nil +} + +// Delete 删除 Sora 账号扩展信息 +func (r *soraAccountRepository) Delete(ctx context.Context, accountID int64) error { + _, err := r.sql.ExecContext(ctx, ` + DELETE FROM sora_accounts WHERE account_id = $1 + `, accountID) + return err +} diff --git a/backend/internal/repository/sora_generation_repo.go b/backend/internal/repository/sora_generation_repo.go new file mode 100644 index 00000000..aaf3cb2f --- /dev/null +++ b/backend/internal/repository/sora_generation_repo.go @@ -0,0 +1,419 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// soraGenerationRepository 实现 service.SoraGenerationRepository 接口。 +// 使用原生 SQL 操作 sora_generations 表。 +type soraGenerationRepository struct { + sql *sql.DB +} + +// NewSoraGenerationRepository 创建 Sora 生成记录仓储实例。 +func NewSoraGenerationRepository(sqlDB *sql.DB) service.SoraGenerationRepository { + return &soraGenerationRepository{sql: sqlDB} +} + +func (r *soraGenerationRepository) Create(ctx context.Context, gen *service.SoraGeneration) error { + mediaURLsJSON, _ := json.Marshal(gen.MediaURLs) + s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys) + + err := r.sql.QueryRowContext(ctx, ` + INSERT INTO sora_generations ( + user_id, api_key_id, model, prompt, media_type, + status, media_url, media_urls, file_size_bytes, + storage_type, s3_object_keys, upstream_task_id, error_message + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) + RETURNING id, 
created_at + `, + gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType, + gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes, + gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage, + ).Scan(&gen.ID, &gen.CreatedAt) + return err +} + +// CreatePendingWithLimit 在单事务内执行“并发上限检查 + 创建”,避免 count+create 竞态。 +func (r *soraGenerationRepository) CreatePendingWithLimit( + ctx context.Context, + gen *service.SoraGeneration, + activeStatuses []string, + maxActive int64, +) error { + if gen == nil { + return fmt.Errorf("generation is nil") + } + if maxActive <= 0 { + return r.Create(ctx, gen) + } + if len(activeStatuses) == 0 { + activeStatuses = []string{service.SoraGenStatusPending, service.SoraGenStatusGenerating} + } + + tx, err := r.sql.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + // 使用用户级 advisory lock 串行化并发创建,避免超限竞态。 + if _, err := tx.ExecContext(ctx, `SELECT pg_advisory_xact_lock($1)`, gen.UserID); err != nil { + return err + } + + placeholders := make([]string, len(activeStatuses)) + args := make([]any, 0, 1+len(activeStatuses)) + args = append(args, gen.UserID) + for i, s := range activeStatuses { + placeholders[i] = fmt.Sprintf("$%d", i+2) + args = append(args, s) + } + countQuery := fmt.Sprintf( + `SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)`, + strings.Join(placeholders, ","), + ) + var activeCount int64 + if err := tx.QueryRowContext(ctx, countQuery, args...).Scan(&activeCount); err != nil { + return err + } + if activeCount >= maxActive { + return service.ErrSoraGenerationConcurrencyLimit + } + + mediaURLsJSON, _ := json.Marshal(gen.MediaURLs) + s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys) + if err := tx.QueryRowContext(ctx, ` + INSERT INTO sora_generations ( + user_id, api_key_id, model, prompt, media_type, + status, media_url, media_urls, file_size_bytes, + storage_type, s3_object_keys, upstream_task_id, error_message + ) VALUES ($1, $2, $3, $4, 
$5, $6, $7, $8, $9, $10, $11, $12, $13) + RETURNING id, created_at + `, + gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType, + gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes, + gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage, + ).Scan(&gen.ID, &gen.CreatedAt); err != nil { + return err + } + + return tx.Commit() +} + +func (r *soraGenerationRepository) GetByID(ctx context.Context, id int64) (*service.SoraGeneration, error) { + gen := &service.SoraGeneration{} + var mediaURLsJSON, s3KeysJSON []byte + var completedAt sql.NullTime + var apiKeyID sql.NullInt64 + + err := r.sql.QueryRowContext(ctx, ` + SELECT id, user_id, api_key_id, model, prompt, media_type, + status, media_url, media_urls, file_size_bytes, + storage_type, s3_object_keys, upstream_task_id, error_message, + created_at, completed_at + FROM sora_generations WHERE id = $1 + `, id).Scan( + &gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType, + &gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes, + &gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage, + &gen.CreatedAt, &completedAt, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("生成记录不存在") + } + return nil, err + } + + if apiKeyID.Valid { + gen.APIKeyID = &apiKeyID.Int64 + } + if completedAt.Valid { + gen.CompletedAt = &completedAt.Time + } + _ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs) + _ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys) + return gen, nil +} + +func (r *soraGenerationRepository) Update(ctx context.Context, gen *service.SoraGeneration) error { + mediaURLsJSON, _ := json.Marshal(gen.MediaURLs) + s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys) + + var completedAt *time.Time + if gen.CompletedAt != nil { + completedAt = gen.CompletedAt + } + + _, err := r.sql.ExecContext(ctx, ` + UPDATE sora_generations SET + status = $2, media_url = $3, media_urls = $4, file_size_bytes = $5, + storage_type = $6, s3_object_keys = $7, 
upstream_task_id = $8, + error_message = $9, completed_at = $10 + WHERE id = $1 + `, + gen.ID, gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes, + gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, + gen.ErrorMessage, completedAt, + ) + return err +} + +// UpdateGeneratingIfPending 仅当状态为 pending 时更新为 generating。 +func (r *soraGenerationRepository) UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error) { + result, err := r.sql.ExecContext(ctx, ` + UPDATE sora_generations + SET status = $2, upstream_task_id = $3 + WHERE id = $1 AND status = $4 + `, + id, service.SoraGenStatusGenerating, upstreamTaskID, service.SoraGenStatusPending, + ) + if err != nil { + return false, err + } + affected, err := result.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +// UpdateCompletedIfActive 仅当状态为 pending/generating 时更新为 completed。 +func (r *soraGenerationRepository) UpdateCompletedIfActive( + ctx context.Context, + id int64, + mediaURL string, + mediaURLs []string, + storageType string, + s3Keys []string, + fileSizeBytes int64, + completedAt time.Time, +) (bool, error) { + mediaURLsJSON, _ := json.Marshal(mediaURLs) + s3KeysJSON, _ := json.Marshal(s3Keys) + result, err := r.sql.ExecContext(ctx, ` + UPDATE sora_generations + SET status = $2, + media_url = $3, + media_urls = $4, + file_size_bytes = $5, + storage_type = $6, + s3_object_keys = $7, + error_message = '', + completed_at = $8 + WHERE id = $1 AND status IN ($9, $10) + `, + id, service.SoraGenStatusCompleted, mediaURL, mediaURLsJSON, fileSizeBytes, + storageType, s3KeysJSON, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating, + ) + if err != nil { + return false, err + } + affected, err := result.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +// UpdateFailedIfActive 仅当状态为 pending/generating 时更新为 failed。 +func (r *soraGenerationRepository) UpdateFailedIfActive( + ctx 
context.Context, + id int64, + errMsg string, + completedAt time.Time, +) (bool, error) { + result, err := r.sql.ExecContext(ctx, ` + UPDATE sora_generations + SET status = $2, + error_message = $3, + completed_at = $4 + WHERE id = $1 AND status IN ($5, $6) + `, + id, service.SoraGenStatusFailed, errMsg, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating, + ) + if err != nil { + return false, err + } + affected, err := result.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +// UpdateCancelledIfActive 仅当状态为 pending/generating 时更新为 cancelled。 +func (r *soraGenerationRepository) UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error) { + result, err := r.sql.ExecContext(ctx, ` + UPDATE sora_generations + SET status = $2, completed_at = $3 + WHERE id = $1 AND status IN ($4, $5) + `, + id, service.SoraGenStatusCancelled, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating, + ) + if err != nil { + return false, err + } + affected, err := result.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +// UpdateStorageIfCompleted 更新已完成记录的存储信息(用于手动保存,不重置 completed_at)。 +func (r *soraGenerationRepository) UpdateStorageIfCompleted( + ctx context.Context, + id int64, + mediaURL string, + mediaURLs []string, + storageType string, + s3Keys []string, + fileSizeBytes int64, +) (bool, error) { + mediaURLsJSON, _ := json.Marshal(mediaURLs) + s3KeysJSON, _ := json.Marshal(s3Keys) + result, err := r.sql.ExecContext(ctx, ` + UPDATE sora_generations + SET media_url = $2, + media_urls = $3, + file_size_bytes = $4, + storage_type = $5, + s3_object_keys = $6 + WHERE id = $1 AND status = $7 + `, + id, mediaURL, mediaURLsJSON, fileSizeBytes, storageType, s3KeysJSON, service.SoraGenStatusCompleted, + ) + if err != nil { + return false, err + } + affected, err := result.RowsAffected() + if err != nil { + return false, err + } + return 
affected > 0, nil +} + +func (r *soraGenerationRepository) Delete(ctx context.Context, id int64) error { + _, err := r.sql.ExecContext(ctx, `DELETE FROM sora_generations WHERE id = $1`, id) + return err +} + +func (r *soraGenerationRepository) List(ctx context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) { + // 构建 WHERE 条件 + conditions := []string{"user_id = $1"} + args := []any{params.UserID} + argIdx := 2 + + if params.Status != "" { + // 支持逗号分隔的多状态 + statuses := strings.Split(params.Status, ",") + placeholders := make([]string, len(statuses)) + for i, s := range statuses { + placeholders[i] = fmt.Sprintf("$%d", argIdx) + args = append(args, strings.TrimSpace(s)) + argIdx++ + } + conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ","))) + } + if params.StorageType != "" { + storageTypes := strings.Split(params.StorageType, ",") + placeholders := make([]string, len(storageTypes)) + for i, s := range storageTypes { + placeholders[i] = fmt.Sprintf("$%d", argIdx) + args = append(args, strings.TrimSpace(s)) + argIdx++ + } + conditions = append(conditions, fmt.Sprintf("storage_type IN (%s)", strings.Join(placeholders, ","))) + } + if params.MediaType != "" { + conditions = append(conditions, fmt.Sprintf("media_type = $%d", argIdx)) + args = append(args, params.MediaType) + argIdx++ + } + + whereClause := "WHERE " + strings.Join(conditions, " AND ") + + // 计数 + var total int64 + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations %s", whereClause) + if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil { + return nil, 0, err + } + + // 分页查询 + offset := (params.Page - 1) * params.PageSize + listQuery := fmt.Sprintf(` + SELECT id, user_id, api_key_id, model, prompt, media_type, + status, media_url, media_urls, file_size_bytes, + storage_type, s3_object_keys, upstream_task_id, error_message, + created_at, completed_at + FROM sora_generations %s + 
ORDER BY created_at DESC + LIMIT $%d OFFSET $%d + `, whereClause, argIdx, argIdx+1) + args = append(args, params.PageSize, offset) + + rows, err := r.sql.QueryContext(ctx, listQuery, args...) + if err != nil { + return nil, 0, err + } + defer func() { + _ = rows.Close() + }() + + var results []*service.SoraGeneration + for rows.Next() { + gen := &service.SoraGeneration{} + var mediaURLsJSON, s3KeysJSON []byte + var completedAt sql.NullTime + var apiKeyID sql.NullInt64 + + if err := rows.Scan( + &gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType, + &gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes, + &gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage, + &gen.CreatedAt, &completedAt, + ); err != nil { + return nil, 0, err + } + + if apiKeyID.Valid { + gen.APIKeyID = &apiKeyID.Int64 + } + if completedAt.Valid { + gen.CompletedAt = &completedAt.Time + } + _ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs) + _ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys) + results = append(results, gen) + } + + return results, total, rows.Err() +} + +func (r *soraGenerationRepository) CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error) { + if len(statuses) == 0 { + return 0, nil + } + + placeholders := make([]string, len(statuses)) + args := []any{userID} + for i, s := range statuses { + placeholders[i] = fmt.Sprintf("$%d", i+2) + args = append(args, s) + } + + var count int64 + query := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)", strings.Join(placeholders, ",")) + err := r.sql.QueryRowContext(ctx, query, args...).Scan(&count) + return count, err +} diff --git a/backend/internal/repository/usage_cleanup_repo.go b/backend/internal/repository/usage_cleanup_repo.go index 9c021357..1a25696e 100644 --- a/backend/internal/repository/usage_cleanup_repo.go +++ b/backend/internal/repository/usage_cleanup_repo.go @@ -362,7 +362,12 @@ func 
buildUsageCleanupWhere(filters service.UsageCleanupFilters) (string, []any) idx++ } } - if filters.Stream != nil { + if filters.RequestType != nil { + condition, conditionArgs := buildRequestTypeFilterCondition(idx, *filters.RequestType) + conditions = append(conditions, condition) + args = append(args, conditionArgs...) + idx += len(conditionArgs) + } else if filters.Stream != nil { conditions = append(conditions, fmt.Sprintf("stream = $%d", idx)) args = append(args, *filters.Stream) idx++ diff --git a/backend/internal/repository/usage_cleanup_repo_test.go b/backend/internal/repository/usage_cleanup_repo_test.go index 0ca30ec7..1ac7cca5 100644 --- a/backend/internal/repository/usage_cleanup_repo_test.go +++ b/backend/internal/repository/usage_cleanup_repo_test.go @@ -466,6 +466,38 @@ func TestBuildUsageCleanupWhere(t *testing.T) { require.Equal(t, []any{start, end, userID, apiKeyID, accountID, groupID, "gpt-4", stream, billingType}, args) } +func TestBuildUsageCleanupWhereRequestTypePriority(t *testing.T) { + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + requestType := int16(service.RequestTypeWSV2) + stream := false + + where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{ + StartTime: start, + EndTime: end, + RequestType: &requestType, + Stream: &stream, + }) + + require.Equal(t, "created_at >= $1 AND created_at <= $2 AND (request_type = $3 OR (request_type = 0 AND openai_ws_mode = TRUE))", where) + require.Equal(t, []any{start, end, requestType}, args) +} + +func TestBuildUsageCleanupWhereRequestTypeLegacyFallback(t *testing.T) { + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + requestType := int16(service.RequestTypeStream) + + where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{ + StartTime: start, + EndTime: end, + RequestType: &requestType, + }) + + require.Equal(t, "created_at >= $1 AND created_at <= $2 AND (request_type = $3 OR (request_type = 
0 AND stream = TRUE AND openai_ws_mode = FALSE))", where) + require.Equal(t, []any{start, end, requestType}, args) +} + func TestBuildUsageCleanupWhereModelEmpty(t *testing.T) { start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) end := start.Add(24 * time.Hour) diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 2db1764f..44079a55 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -22,7 +22,23 @@ import ( "github.com/lib/pq" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, reasoning_effort, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at" + +// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL +var dateFormatWhitelist = map[string]string{ + "hour": "YYYY-MM-DD HH24:00", + "day": "YYYY-MM-DD", + "week": "IYYY-IW", + "month": "YYYY-MM", +} + +// safeDateFormat 根据白名单获取 dateFormat,未匹配时返回默认值 +func safeDateFormat(granularity string) string { + if f, ok := dateFormatWhitelist[granularity]; 
ok { + return f + } + return "YYYY-MM-DD" +} type usageLogRepository struct { client *dbent.Client @@ -82,6 +98,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) log.RequestID = requestID rateMultiplier := log.RateMultiplier + log.SyncRequestTypeAndLegacyFields() + requestType := int16(log.RequestType) query := ` INSERT INTO usage_logs ( @@ -107,26 +125,30 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) rate_multiplier, account_rate_multiplier, billing_type, + request_type, stream, + openai_ws_mode, duration_ms, first_token_ms, user_agent, - ip_address, - image_count, - image_size, - reasoning_effort, - created_at - ) VALUES ( - $1, $2, $3, $4, $5, - $6, $7, - $8, $9, $10, $11, - $12, $13, - $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31 - ) - ON CONFLICT (request_id, api_key_id) DO NOTHING - RETURNING id, created_at - ` + ip_address, + image_count, + image_size, + media_type, + reasoning_effort, + cache_ttl_overridden, + created_at + ) VALUES ( + $1, $2, $3, $4, $5, + $6, $7, + $8, $9, $10, $11, + $12, $13, + $14, $15, $16, $17, $18, $19, + $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35 + ) + ON CONFLICT (request_id, api_key_id) DO NOTHING + RETURNING id, created_at + ` groupID := nullInt64(log.GroupID) subscriptionID := nullInt64(log.SubscriptionID) @@ -135,6 +157,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) userAgent := nullString(log.UserAgent) ipAddress := nullString(log.IPAddress) imageSize := nullString(log.ImageSize) + mediaType := nullString(log.MediaType) reasoningEffort := nullString(log.ReasoningEffort) var requestIDArg any @@ -165,14 +188,18 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) rateMultiplier, log.AccountRateMultiplier, log.BillingType, + requestType, log.Stream, + log.OpenAIWSMode, duration, firstToken, userAgent, 
ipAddress, log.ImageCount, imageSize, + mediaType, reasoningEffort, + log.CacheTTLOverridden, createdAt, } if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil { @@ -471,25 +498,46 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte } func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Context, stats *DashboardStats, startUTC, endUTC, todayUTC, now time.Time) error { - totalStatsQuery := ` + todayEnd := todayUTC.Add(24 * time.Hour) + combinedStatsQuery := ` + WITH scoped AS ( + SELECT + created_at, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + total_cost, + actual_cost, + COALESCE(duration_ms, 0) AS duration_ms + FROM usage_logs + WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz) + AND created_at < GREATEST($2::timestamptz, $4::timestamptz) + ) SELECT - COUNT(*) as total_requests, - COALESCE(SUM(input_tokens), 0) as total_input_tokens, - COALESCE(SUM(output_tokens), 0) as total_output_tokens, - COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens, - COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens, - COALESCE(SUM(total_cost), 0) as total_cost, - COALESCE(SUM(actual_cost), 0) as total_actual_cost, - COALESCE(SUM(COALESCE(duration_ms, 0)), 0) as total_duration_ms - FROM usage_logs - WHERE created_at >= $1 AND created_at < $2 + COUNT(*) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz) AS total_requests, + COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_input_tokens, + COALESCE(SUM(output_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_output_tokens, + COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_creation_tokens, + COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at 
>= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_read_tokens, + COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cost, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_actual_cost, + COALESCE(SUM(duration_ms) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_duration_ms, + COUNT(*) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz) AS today_requests, + COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_input_tokens, + COALESCE(SUM(output_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_output_tokens, + COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_creation_tokens, + COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_read_tokens, + COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cost, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_actual_cost + FROM scoped ` var totalDurationMs int64 if err := scanSingleRow( ctx, r.sql, - totalStatsQuery, - []any{startUTC, endUTC}, + combinedStatsQuery, + []any{startUTC, endUTC, todayUTC, todayEnd}, &stats.TotalRequests, &stats.TotalInputTokens, &stats.TotalOutputTokens, @@ -498,32 +546,6 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co &stats.TotalCost, &stats.TotalActualCost, &totalDurationMs, - ); err != nil { - return err - } - stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens 
- if stats.TotalRequests > 0 { - stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests) - } - - todayEnd := todayUTC.Add(24 * time.Hour) - todayStatsQuery := ` - SELECT - COUNT(*) as today_requests, - COALESCE(SUM(input_tokens), 0) as today_input_tokens, - COALESCE(SUM(output_tokens), 0) as today_output_tokens, - COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens, - COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens, - COALESCE(SUM(total_cost), 0) as today_cost, - COALESCE(SUM(actual_cost), 0) as today_actual_cost - FROM usage_logs - WHERE created_at >= $1 AND created_at < $2 - ` - if err := scanSingleRow( - ctx, - r.sql, - todayStatsQuery, - []any{todayUTC, todayEnd}, &stats.TodayRequests, &stats.TodayInputTokens, &stats.TodayOutputTokens, @@ -534,25 +556,28 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co ); err != nil { return err } - stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens - - activeUsersQuery := ` - SELECT COUNT(DISTINCT user_id) as active_users - FROM usage_logs - WHERE created_at >= $1 AND created_at < $2 - ` - if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd}, &stats.ActiveUsers); err != nil { - return err + stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens + if stats.TotalRequests > 0 { + stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests) } + stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens + hourStart := now.UTC().Truncate(time.Hour) hourEnd := hourStart.Add(time.Hour) - hourlyActiveQuery := ` - SELECT COUNT(DISTINCT user_id) as active_users - FROM usage_logs - WHERE created_at >= $1 AND created_at < $2 + activeUsersQuery := ` + WITH scoped AS ( + SELECT 
user_id, created_at + FROM usage_logs + WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz) + AND created_at < GREATEST($2::timestamptz, $4::timestamptz) + ) + SELECT + COUNT(DISTINCT CASE WHEN created_at >= $1::timestamptz AND created_at < $2::timestamptz THEN user_id END) AS active_users, + COUNT(DISTINCT CASE WHEN created_at >= $3::timestamptz AND created_at < $4::timestamptz THEN user_id END) AS hourly_active_users + FROM scoped ` - if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart, hourEnd}, &stats.HourlyActiveUsers); err != nil { + if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd, hourStart, hourEnd}, &stats.ActiveUsers, &stats.HourlyActiveUsers); err != nil { return err } @@ -564,7 +589,7 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, } func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" + query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" logs, err := r.queryUsageLogs(ctx, query, userID, startTime, endTime) return logs, nil, err } @@ -810,19 +835,19 @@ func resolveUsageStatsTimezone() string { } func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" + query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" logs, err := 
r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime) return logs, nil, err } func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" + query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" logs, err := r.queryUsageLogs(ctx, query, accountID, startTime, endTime) return logs, nil, err } func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" + query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime) return logs, nil, err } @@ -894,6 +919,114 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI return stats, nil } +// GetAccountWindowStatsBatch 批量获取同一窗口起点下多个账号的统计数据。 +// 返回 map[accountID]*AccountStats,未命中的账号会返回零值统计,便于上层直接复用。 +func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) { + result := make(map[int64]*usagestats.AccountStats, len(accountIDs)) + if len(accountIDs) == 0 { + return result, nil + } + + query := ` + SELECT + account_id, + COUNT(*) as requests, + COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, + COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) 
as cost, + COALESCE(SUM(total_cost), 0) as standard_cost, + COALESCE(SUM(actual_cost), 0) as user_cost + FROM usage_logs + WHERE account_id = ANY($1) AND created_at >= $2 + GROUP BY account_id + ` + rows, err := r.sql.QueryContext(ctx, query, pq.Array(accountIDs), startTime) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + for rows.Next() { + var accountID int64 + stats := &usagestats.AccountStats{} + if err := rows.Scan( + &accountID, + &stats.Requests, + &stats.Tokens, + &stats.Cost, + &stats.StandardCost, + &stats.UserCost, + ); err != nil { + return nil, err + } + result[accountID] = stats + } + if err := rows.Err(); err != nil { + return nil, err + } + + for _, accountID := range accountIDs { + if _, ok := result[accountID]; !ok { + result[accountID] = &usagestats.AccountStats{} + } + } + return result, nil +} + +// GetGeminiUsageTotalsBatch 批量聚合 Gemini 账号在窗口内的 Pro/Flash 请求与用量。 +// 模型分类规则与 service.geminiModelClassFromName 一致:model 包含 flash/lite 视为 flash,其余视为 pro。 +func (r *usageLogRepository) GetGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, startTime, endTime time.Time) (map[int64]service.GeminiUsageTotals, error) { + result := make(map[int64]service.GeminiUsageTotals, len(accountIDs)) + if len(accountIDs) == 0 { + return result, nil + } + + query := ` + SELECT + account_id, + COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 1 ELSE 0 END), 0) AS flash_requests, + COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE 1 END), 0) AS pro_requests, + COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) ELSE 0 END), 0) AS flash_tokens, + COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 
ELSE (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) END), 0) AS pro_tokens, + COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN actual_cost ELSE 0 END), 0) AS flash_cost, + COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE actual_cost END), 0) AS pro_cost + FROM usage_logs + WHERE account_id = ANY($1) AND created_at >= $2 AND created_at < $3 + GROUP BY account_id + ` + rows, err := r.sql.QueryContext(ctx, query, pq.Array(accountIDs), startTime, endTime) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + for rows.Next() { + var accountID int64 + var totals service.GeminiUsageTotals + if err := rows.Scan( + &accountID, + &totals.FlashRequests, + &totals.ProRequests, + &totals.FlashTokens, + &totals.ProTokens, + &totals.FlashCost, + &totals.ProCost, + ); err != nil { + return nil, err + } + result[accountID] = totals + } + if err := rows.Err(); err != nil { + return nil, err + } + + for _, accountID := range accountIDs { + if _, ok := result[accountID]; !ok { + result[accountID] = service.GeminiUsageTotals{} + } + } + return result, nil +} + // TrendDataPoint represents a single point in trend data type TrendDataPoint = usagestats.TrendDataPoint @@ -908,10 +1041,7 @@ type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint // GetAPIKeyUsageTrend returns usage trend data grouped by API key and date func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) { - dateFormat := "YYYY-MM-DD" - if granularity == "hour" { - dateFormat = "YYYY-MM-DD HH24:00" - } + dateFormat := safeDateFormat(granularity) query := fmt.Sprintf(` WITH top_keys AS ( @@ -966,10 +1096,7 @@ func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, // 
GetUserUsageTrend returns usage trend data grouped by user and date func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) { - dateFormat := "YYYY-MM-DD" - if granularity == "hour" { - dateFormat = "YYYY-MM-DD HH24:00" - } + dateFormat := safeDateFormat(granularity) query := fmt.Sprintf(` WITH top_users AS ( @@ -1228,10 +1355,7 @@ func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKey // GetUserUsageTrendByUserID 获取指定用户的使用趋势 func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) { - dateFormat := "YYYY-MM-DD" - if granularity == "hour" { - dateFormat = "YYYY-MM-DD HH24:00" - } + dateFormat := safeDateFormat(granularity) query := fmt.Sprintf(` SELECT @@ -1334,10 +1458,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1)) args = append(args, filters.Model) } - if filters.Stream != nil { - conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1)) - args = append(args, *filters.Stream) - } + conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream) if filters.BillingType != nil { conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) args = append(args, int16(*filters.BillingType)) @@ -1352,7 +1473,16 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat } whereClause := buildWhere(conditions) - logs, page, err := r.listUsageLogsWithPagination(ctx, whereClause, args, params) + var ( + logs []service.UsageLog + page *pagination.PaginationResult + err error + ) + if shouldUseFastUsageLogTotal(filters) { + logs, page, err = r.listUsageLogsWithFastPagination(ctx, whereClause, args, 
params) + } else { + logs, page, err = r.listUsageLogsWithPagination(ctx, whereClause, args, params) + } if err != nil { return nil, nil, err } @@ -1363,71 +1493,86 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat return logs, page, nil } +func shouldUseFastUsageLogTotal(filters UsageLogFilters) bool { + if filters.ExactTotal { + return false + } + // 强选择过滤下记录集通常较小,保留精确总数。 + return filters.UserID == 0 && filters.APIKeyID == 0 && filters.AccountID == 0 +} + // UsageStats represents usage statistics type UsageStats = usagestats.UsageStats // BatchUserUsageStats represents usage stats for a single user type BatchUserUsageStats = usagestats.BatchUserUsageStats -// GetBatchUserUsageStats gets today and total actual_cost for multiple users -func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) { +func normalizePositiveInt64IDs(ids []int64) []int64 { + if len(ids) == 0 { + return nil + } + seen := make(map[int64]struct{}, len(ids)) + out := make([]int64, 0, len(ids)) + for _, id := range ids { + if id <= 0 { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + out = append(out, id) + } + return out +} + +// GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range. +// If startTime is zero, defaults to 30 days ago. 
+func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) { result := make(map[int64]*BatchUserUsageStats) - if len(userIDs) == 0 { + normalizedUserIDs := normalizePositiveInt64IDs(userIDs) + if len(normalizedUserIDs) == 0 { return result, nil } - for _, id := range userIDs { + // 默认最近 30 天 + if startTime.IsZero() { + startTime = time.Now().AddDate(0, 0, -30) + } + if endTime.IsZero() { + endTime = time.Now() + } + + for _, id := range normalizedUserIDs { result[id] = &BatchUserUsageStats{UserID: id} } query := ` - SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost + SELECT + user_id, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost FROM usage_logs WHERE user_id = ANY($1) + AND created_at >= LEAST($2, $4) GROUP BY user_id ` - rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs)) + today := timezone.Today() + rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedUserIDs), startTime, endTime, today) if err != nil { return nil, err } for rows.Next() { var userID int64 var total float64 - if err := rows.Scan(&userID, &total); err != nil { + var todayTotal float64 + if err := rows.Scan(&userID, &total, &todayTotal); err != nil { _ = rows.Close() return nil, err } if stats, ok := result[userID]; ok { stats.TotalActualCost = total - } - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - - today := timezone.Today() - todayQuery := ` - SELECT user_id, COALESCE(SUM(actual_cost), 0) as today_cost - FROM usage_logs - WHERE user_id = ANY($1) AND created_at >= $2 - GROUP BY user_id - ` - rows, err = r.sql.QueryContext(ctx, todayQuery, pq.Array(userIDs), today) - if err != nil { - return nil, err - } - for rows.Next() { - var userID int64 - var total 
float64 - if err := rows.Scan(&userID, &total); err != nil { - _ = rows.Close() - return nil, err - } - if stats, ok := result[userID]; ok { - stats.TodayActualCost = total + stats.TodayActualCost = todayTotal } } if err := rows.Close(); err != nil { @@ -1443,65 +1588,53 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs // BatchAPIKeyUsageStats represents usage stats for a single API key type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats -// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys -func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) { +// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys within a time range. +// If startTime is zero, defaults to 30 days ago. +func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) { result := make(map[int64]*BatchAPIKeyUsageStats) - if len(apiKeyIDs) == 0 { + normalizedAPIKeyIDs := normalizePositiveInt64IDs(apiKeyIDs) + if len(normalizedAPIKeyIDs) == 0 { return result, nil } - for _, id := range apiKeyIDs { + // 默认最近 30 天 + if startTime.IsZero() { + startTime = time.Now().AddDate(0, 0, -30) + } + if endTime.IsZero() { + endTime = time.Now() + } + + for _, id := range normalizedAPIKeyIDs { result[id] = &BatchAPIKeyUsageStats{APIKeyID: id} } query := ` - SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost + SELECT + api_key_id, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost FROM usage_logs WHERE api_key_id = ANY($1) + AND created_at >= LEAST($2, $4) GROUP BY api_key_id ` - rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs)) + today := timezone.Today() + rows, err := 
r.sql.QueryContext(ctx, query, pq.Array(normalizedAPIKeyIDs), startTime, endTime, today) if err != nil { return nil, err } for rows.Next() { var apiKeyID int64 var total float64 - if err := rows.Scan(&apiKeyID, &total); err != nil { + var todayTotal float64 + if err := rows.Scan(&apiKeyID, &total, &todayTotal); err != nil { _ = rows.Close() return nil, err } if stats, ok := result[apiKeyID]; ok { stats.TotalActualCost = total - } - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - - today := timezone.Today() - todayQuery := ` - SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as today_cost - FROM usage_logs - WHERE api_key_id = ANY($1) AND created_at >= $2 - GROUP BY api_key_id - ` - rows, err = r.sql.QueryContext(ctx, todayQuery, pq.Array(apiKeyIDs), today) - if err != nil { - return nil, err - } - for rows.Next() { - var apiKeyID int64 - var total float64 - if err := rows.Scan(&apiKeyID, &total); err != nil { - _ = rows.Close() - return nil, err - } - if stats, ok := result[apiKeyID]; ok { - stats.TodayActualCost = total + stats.TodayActualCost = todayTotal } } if err := rows.Close(); err != nil { @@ -1515,12 +1648,16 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe } // GetUsageTrendWithFilters returns usage trend data with optional filters -func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) { - dateFormat := "YYYY-MM-DD" - if granularity == "hour" { - dateFormat = "YYYY-MM-DD HH24:00" +func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []TrendDataPoint, err error) 
{ + if shouldUsePreaggregatedTrend(granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) { + aggregated, aggregatedErr := r.getUsageTrendFromAggregates(ctx, startTime, endTime, granularity) + if aggregatedErr == nil && len(aggregated) > 0 { + return aggregated, nil + } } + dateFormat := safeDateFormat(granularity) + query := fmt.Sprintf(` SELECT TO_CHAR(created_at, '%s') as date, @@ -1556,10 +1693,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start query += fmt.Sprintf(" AND model = $%d", len(args)+1) args = append(args, model) } - if stream != nil { - query += fmt.Sprintf(" AND stream = $%d", len(args)+1) - args = append(args, *stream) - } + query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) if billingType != nil { query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) args = append(args, int16(*billingType)) @@ -1586,8 +1720,80 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start return results, nil } +func shouldUsePreaggregatedTrend(granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) bool { + if granularity != "day" && granularity != "hour" { + return false + } + return userID == 0 && + apiKeyID == 0 && + accountID == 0 && + groupID == 0 && + model == "" && + requestType == nil && + stream == nil && + billingType == nil +} + +func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) { + dateFormat := safeDateFormat(granularity) + query := "" + args := []any{startTime, endTime} + + switch granularity { + case "hour": + query = fmt.Sprintf(` + SELECT + TO_CHAR(bucket_start, '%s') as date, + total_requests as requests, + input_tokens, + output_tokens, + (cache_creation_tokens + cache_read_tokens) as cache_tokens, + (input_tokens + 
output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens, + total_cost as cost, + actual_cost + FROM usage_dashboard_hourly + WHERE bucket_start >= $1 AND bucket_start < $2 + ORDER BY bucket_start ASC + `, dateFormat) + case "day": + query = fmt.Sprintf(` + SELECT + TO_CHAR(bucket_date::timestamp, '%s') as date, + total_requests as requests, + input_tokens, + output_tokens, + (cache_creation_tokens + cache_read_tokens) as cache_tokens, + (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) as total_tokens, + total_cost as cost, + actual_cost + FROM usage_dashboard_daily + WHERE bucket_date >= $1::date AND bucket_date < $2::date + ORDER BY bucket_date ASC + `, dateFormat) + default: + return nil, nil + } + + rows, err := r.sql.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer func() { + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + results = nil + } + }() + + results, err = scanTrendRows(rows) + if err != nil { + return nil, err + } + return results, nil +} + // GetModelStatsWithFilters returns model statistics with optional filters -func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) (results []ModelStat, err error) { +func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) { actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" // 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。 if accountID > 0 && userID == 0 && apiKeyID == 0 { @@ -1624,10 +1830,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start query += fmt.Sprintf(" AND group_id = $%d", len(args)+1) args = append(args, groupID) } - if 
stream != nil { - query += fmt.Sprintf(" AND stream = $%d", len(args)+1) - args = append(args, *stream) - } + query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) if billingType != nil { query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) args = append(args, int16(*billingType)) @@ -1654,6 +1857,77 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start return results, nil } +// GetGroupStatsWithFilters returns group usage statistics with optional filters +func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []usagestats.GroupStat, err error) { + query := ` + SELECT + COALESCE(ul.group_id, 0) as group_id, + COALESCE(g.name, '') as group_name, + COUNT(*) as requests, + COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens, + COALESCE(SUM(ul.total_cost), 0) as cost, + COALESCE(SUM(ul.actual_cost), 0) as actual_cost + FROM usage_logs ul + LEFT JOIN groups g ON g.id = ul.group_id + WHERE ul.created_at >= $1 AND ul.created_at < $2 + ` + + args := []any{startTime, endTime} + if userID > 0 { + query += fmt.Sprintf(" AND ul.user_id = $%d", len(args)+1) + args = append(args, userID) + } + if apiKeyID > 0 { + query += fmt.Sprintf(" AND ul.api_key_id = $%d", len(args)+1) + args = append(args, apiKeyID) + } + if accountID > 0 { + query += fmt.Sprintf(" AND ul.account_id = $%d", len(args)+1) + args = append(args, accountID) + } + if groupID > 0 { + query += fmt.Sprintf(" AND ul.group_id = $%d", len(args)+1) + args = append(args, groupID) + } + query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) + if billingType != nil { + query += fmt.Sprintf(" AND ul.billing_type = $%d", len(args)+1) + args = append(args, int16(*billingType)) + } + query += " GROUP BY 
ul.group_id, g.name ORDER BY total_tokens DESC" + + rows, err := r.sql.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer func() { + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + results = nil + } + }() + + results = make([]usagestats.GroupStat, 0) + for rows.Next() { + var row usagestats.GroupStat + if err := rows.Scan( + &row.GroupID, + &row.GroupName, + &row.Requests, + &row.TotalTokens, + &row.Cost, + &row.ActualCost, + ); err != nil { + return nil, err + } + results = append(results, row) + } + if err := rows.Err(); err != nil { + return nil, err + } + return results, nil +} + // GetGlobalStats gets usage statistics for all users within a time range func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) { query := ` @@ -1714,10 +1988,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1)) args = append(args, filters.Model) } - if filters.Stream != nil { - conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1)) - args = append(args, *filters.Stream) - } + conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream) if filters.BillingType != nil { conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) args = append(args, int16(*filters.BillingType)) @@ -1937,7 +2208,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID } } - models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil) + models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil, nil) if err != nil { models = []ModelStat{} } @@ -1968,6 +2239,35 @@ func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, wh return logs, paginationResultFromTotal(total, 
params), nil } +func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { + limit := params.Limit() + offset := params.Offset() + + limitPos := len(args) + 1 + offsetPos := len(args) + 2 + listArgs := append(append([]any{}, args...), limit+1, offset) + query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos) + + logs, err := r.queryUsageLogs(ctx, query, listArgs...) + if err != nil { + return nil, nil, err + } + + hasMore := false + if len(logs) > limit { + hasMore = true + logs = logs[:limit] + } + + total := int64(offset) + int64(len(logs)) + if hasMore { + // 只保证“还有下一页”,避免对超大表做全量 COUNT(*)。 + total = int64(offset) + int64(limit) + 1 + } + + return logs, paginationResultFromTotal(total, params), nil +} + func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) { rows, err := r.sql.QueryContext(ctx, query, args...) 
if err != nil { @@ -2187,14 +2487,18 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e rateMultiplier float64 accountRateMultiplier sql.NullFloat64 billingType int16 + requestTypeRaw int16 stream bool + openaiWSMode bool durationMs sql.NullInt64 firstTokenMs sql.NullInt64 userAgent sql.NullString ipAddress sql.NullString imageCount int imageSize sql.NullString + mediaType sql.NullString reasoningEffort sql.NullString + cacheTTLOverridden bool createdAt time.Time ) @@ -2222,14 +2526,18 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &rateMultiplier, &accountRateMultiplier, &billingType, + &requestTypeRaw, &stream, + &openaiWSMode, &durationMs, &firstTokenMs, &userAgent, &ipAddress, &imageCount, &imageSize, + &mediaType, &reasoningEffort, + &cacheTTLOverridden, &createdAt, ); err != nil { return nil, err @@ -2256,10 +2564,16 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e RateMultiplier: rateMultiplier, AccountRateMultiplier: nullFloat64Ptr(accountRateMultiplier), BillingType: int8(billingType), - Stream: stream, + RequestType: service.RequestTypeFromInt16(requestTypeRaw), ImageCount: imageCount, + CacheTTLOverridden: cacheTTLOverridden, CreatedAt: createdAt, } + // 先回填 legacy 字段,再基于 legacy + request_type 计算最终请求类型,保证历史数据兼容。 + log.Stream = stream + log.OpenAIWSMode = openaiWSMode + log.RequestType = log.EffectiveRequestType() + log.Stream, log.OpenAIWSMode = service.ApplyLegacyRequestFields(log.RequestType, stream, openaiWSMode) if requestID.Valid { log.RequestID = requestID.String @@ -2289,6 +2603,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if imageSize.Valid { log.ImageSize = &imageSize.String } + if mediaType.Valid { + log.MediaType = &mediaType.String + } if reasoningEffort.Valid { log.ReasoningEffort = &reasoningEffort.String } @@ -2350,6 +2667,50 @@ func buildWhere(conditions []string) string { return "WHERE " + 
strings.Join(conditions, " AND ") } +func appendRequestTypeOrStreamWhereCondition(conditions []string, args []any, requestType *int16, stream *bool) ([]string, []any) { + if requestType != nil { + condition, conditionArgs := buildRequestTypeFilterCondition(len(args)+1, *requestType) + conditions = append(conditions, condition) + args = append(args, conditionArgs...) + return conditions, args + } + if stream != nil { + conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1)) + args = append(args, *stream) + } + return conditions, args +} + +func appendRequestTypeOrStreamQueryFilter(query string, args []any, requestType *int16, stream *bool) (string, []any) { + if requestType != nil { + condition, conditionArgs := buildRequestTypeFilterCondition(len(args)+1, *requestType) + query += " AND " + condition + args = append(args, conditionArgs...) + return query, args + } + if stream != nil { + query += fmt.Sprintf(" AND stream = $%d", len(args)+1) + args = append(args, *stream) + } + return query, args +} + +// buildRequestTypeFilterCondition 在 request_type 过滤时兼容 legacy 字段,避免历史数据漏查。 +func buildRequestTypeFilterCondition(startArgIndex int, requestType int16) (string, []any) { + normalized := service.RequestTypeFromInt16(requestType) + requestTypeArg := int16(normalized) + switch normalized { + case service.RequestTypeSync: + return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND stream = FALSE AND openai_ws_mode = FALSE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg} + case service.RequestTypeStream: + return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND stream = TRUE AND openai_ws_mode = FALSE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg} + case service.RequestTypeWSV2: + return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND openai_ws_mode = TRUE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg} + default: + return 
fmt.Sprintf("request_type = $%d", startArgIndex), []any{requestTypeArg} + } +} + func nullInt64(v *int64) sql.NullInt64 { if v == nil { return sql.NullInt64{} diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index eb220f22..4d50f7de 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -130,6 +130,62 @@ func (s *UsageLogRepoSuite) TestGetByID_ReturnsAccountRateMultiplier() { s.Require().InEpsilon(0.5, *got.AccountRateMultiplier, 0.0001) } +func (s *UsageLogRepoSuite) TestGetByID_ReturnsOpenAIWSMode() { + user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-ws@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-ws", Name: "k"}) + account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-ws"}) + + log := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.New().String(), + Model: "gpt-5.3-codex", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 1.0, + ActualCost: 1.0, + OpenAIWSMode: true, + CreatedAt: timezone.Today().Add(3 * time.Hour), + } + _, err := s.repo.Create(s.ctx, log) + s.Require().NoError(err) + + got, err := s.repo.GetByID(s.ctx, log.ID) + s.Require().NoError(err) + s.Require().True(got.OpenAIWSMode) +} + +func (s *UsageLogRepoSuite) TestGetByID_ReturnsRequestTypeAndLegacyFallback() { + user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-request-type@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-request-type", Name: "k"}) + account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-request-type"}) + + log := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: 
uuid.New().String(), + Model: "gpt-5.3-codex", + RequestType: service.RequestTypeWSV2, + Stream: true, + OpenAIWSMode: false, + InputTokens: 10, + OutputTokens: 20, + TotalCost: 1.0, + ActualCost: 1.0, + CreatedAt: timezone.Today().Add(4 * time.Hour), + } + _, err := s.repo.Create(s.ctx, log) + s.Require().NoError(err) + + got, err := s.repo.GetByID(s.ctx, log.ID) + s.Require().NoError(err) + s.Require().Equal(service.RequestTypeWSV2, got.RequestType) + s.Require().True(got.Stream) + s.Require().True(got.OpenAIWSMode) +} + // --- Delete --- func (s *UsageLogRepoSuite) TestDelete() { @@ -648,7 +704,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() { s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now()) - stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID}) + stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID}, time.Time{}, time.Time{}) s.Require().NoError(err, "GetBatchUserUsageStats") s.Require().Len(stats, 2) s.Require().NotNil(stats[user1.ID]) @@ -656,7 +712,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() { } func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() { - stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{}) + stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{}) s.Require().NoError(err) s.Require().Empty(stats) } @@ -672,13 +728,13 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() { s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now()) - stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID}) + stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID}, time.Time{}, time.Time{}) s.Require().NoError(err, "GetBatchAPIKeyUsageStats") s.Require().Len(stats, 2) } func (s *UsageLogRepoSuite) 
TestGetBatchApiKeyUsageStats_Empty() { - stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{}) + stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{}) s.Require().NoError(err) s.Require().Empty(stats) } @@ -944,17 +1000,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() { endTime := base.Add(48 * time.Hour) // Test with user filter - trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil) + trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil, nil) s.Require().NoError(err, "GetUsageTrendWithFilters user filter") s.Require().Len(trend, 2) // Test with apiKey filter - trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil) + trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil, nil) s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter") s.Require().Len(trend, 2) // Test with both filters - trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil) + trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil, nil) s.Require().NoError(err, "GetUsageTrendWithFilters both filters") s.Require().Len(trend, 2) } @@ -971,7 +1027,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() { startTime := base.Add(-1 * time.Hour) endTime := base.Add(3 * time.Hour) - trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil) + trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil, nil) s.Require().NoError(err, "GetUsageTrendWithFilters hourly") s.Require().Len(trend, 2) } @@ -1017,17 +1073,17 @@ func (s *UsageLogRepoSuite) 
TestGetModelStatsWithFilters() { endTime := base.Add(2 * time.Hour) // Test with user filter - stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil) + stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil, nil) s.Require().NoError(err, "GetModelStatsWithFilters user filter") s.Require().Len(stats, 2) // Test with apiKey filter - stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil) + stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil, nil) s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter") s.Require().Len(stats, 2) // Test with account filter - stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil) + stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil, nil) s.Require().NoError(err, "GetModelStatsWithFilters account filter") s.Require().Len(stats, 2) } diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go new file mode 100644 index 00000000..54eb81e1 --- /dev/null +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -0,0 +1,328 @@ +package repository + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + createdAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + log := &service.UsageLog{ + UserID: 1, + APIKeyID: 2, + AccountID: 3, + RequestID: 
"req-1", + Model: "gpt-5", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 1, + ActualCost: 1, + BillingType: service.BillingTypeBalance, + RequestType: service.RequestTypeWSV2, + Stream: false, + OpenAIWSMode: false, + CreatedAt: createdAt, + } + + mock.ExpectQuery("INSERT INTO usage_logs"). + WithArgs( + log.UserID, + log.APIKeyID, + log.AccountID, + log.RequestID, + log.Model, + sqlmock.AnyArg(), // group_id + sqlmock.AnyArg(), // subscription_id + log.InputTokens, + log.OutputTokens, + log.CacheCreationTokens, + log.CacheReadTokens, + log.CacheCreation5mTokens, + log.CacheCreation1hTokens, + log.InputCost, + log.OutputCost, + log.CacheCreationCost, + log.CacheReadCost, + log.TotalCost, + log.ActualCost, + log.RateMultiplier, + log.AccountRateMultiplier, + log.BillingType, + int16(service.RequestTypeWSV2), + true, + true, + sqlmock.AnyArg(), // duration_ms + sqlmock.AnyArg(), // first_token_ms + sqlmock.AnyArg(), // user_agent + sqlmock.AnyArg(), // ip_address + log.ImageCount, + sqlmock.AnyArg(), // image_size + sqlmock.AnyArg(), // media_type + sqlmock.AnyArg(), // reasoning_effort + log.CacheTTLOverridden, + createdAt, + ). 
+ WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt)) + + inserted, err := repo.Create(context.Background(), log) + require.NoError(t, err) + require.True(t, inserted) + require.Equal(t, int64(99), log.ID) + require.Equal(t, service.RequestTypeWSV2, log.RequestType) + require.True(t, log.Stream) + require.True(t, log.OpenAIWSMode) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + requestType := int16(service.RequestTypeWSV2) + stream := false + filters := usagestats.UsageLogFilters{ + RequestType: &requestType, + Stream: &stream, + ExactTotal: true, + } + + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)"). + WithArgs(requestType). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0))) + mock.ExpectQuery("SELECT .* FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\) ORDER BY id DESC LIMIT \\$2 OFFSET \\$3"). + WithArgs(requestType, 20, 0). + WillReturnRows(sqlmock.NewRows([]string{"id"})) + + logs, page, err := repo.ListWithFilters(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20}, filters) + require.NoError(t, err) + require.Empty(t, logs) + require.NotNil(t, page) + require.Equal(t, int64(0), page.Total) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageLogRepositoryGetUsageTrendWithFiltersRequestTypePriority(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + requestType := int16(service.RequestTypeStream) + stream := true + + mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE\\)\\)"). 
+ WithArgs(start, end, requestType). + WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_tokens", "total_tokens", "cost", "actual_cost"})) + + trend, err := repo.GetUsageTrendWithFilters(context.Background(), start, end, "day", 0, 0, 0, 0, "", &requestType, &stream, nil) + require.NoError(t, err) + require.Empty(t, trend) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + requestType := int16(service.RequestTypeWSV2) + stream := false + + mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)"). + WithArgs(start, end, requestType). + WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "total_tokens", "cost", "actual_cost"})) + + stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil) + require.NoError(t, err) + require.Empty(t, stats) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + requestType := int16(service.RequestTypeSync) + stream := true + filters := usagestats.UsageLogFilters{ + RequestType: &requestType, + Stream: &stream, + } + + mock.ExpectQuery("FROM usage_logs\\s+WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND stream = FALSE AND openai_ws_mode = FALSE\\)\\)"). + WithArgs(requestType). 
+ WillReturnRows(sqlmock.NewRows([]string{ + "total_requests", + "total_input_tokens", + "total_output_tokens", + "total_cache_tokens", + "total_cost", + "total_actual_cost", + "total_account_cost", + "avg_duration_ms", + }).AddRow(int64(1), int64(2), int64(3), int64(4), 1.2, 1.0, 1.2, 20.0)) + + stats, err := repo.GetStatsWithFilters(context.Background(), filters) + require.NoError(t, err) + require.Equal(t, int64(1), stats.TotalRequests) + require.Equal(t, int64(9), stats.TotalTokens) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestBuildRequestTypeFilterConditionLegacyFallback(t *testing.T) { + tests := []struct { + name string + request int16 + wantWhere string + wantArg int16 + }{ + { + name: "sync_with_legacy_fallback", + request: int16(service.RequestTypeSync), + wantWhere: "(request_type = $3 OR (request_type = 0 AND stream = FALSE AND openai_ws_mode = FALSE))", + wantArg: int16(service.RequestTypeSync), + }, + { + name: "stream_with_legacy_fallback", + request: int16(service.RequestTypeStream), + wantWhere: "(request_type = $3 OR (request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE))", + wantArg: int16(service.RequestTypeStream), + }, + { + name: "ws_v2_with_legacy_fallback", + request: int16(service.RequestTypeWSV2), + wantWhere: "(request_type = $3 OR (request_type = 0 AND openai_ws_mode = TRUE))", + wantArg: int16(service.RequestTypeWSV2), + }, + { + name: "invalid_request_type_normalized_to_unknown", + request: int16(99), + wantWhere: "request_type = $3", + wantArg: int16(service.RequestTypeUnknown), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + where, args := buildRequestTypeFilterCondition(3, tt.request) + require.Equal(t, tt.wantWhere, where) + require.Equal(t, []any{tt.wantArg}, args) + }) + } +} + +type usageLogScannerStub struct { + values []any +} + +func (s usageLogScannerStub) Scan(dest ...any) error { + if len(dest) != len(s.values) { + return fmt.Errorf("scan arg count mismatch: got 
%d want %d", len(dest), len(s.values)) + } + for i := range dest { + dv := reflect.ValueOf(dest[i]) + if dv.Kind() != reflect.Ptr { + return fmt.Errorf("dest[%d] is not pointer", i) + } + dv.Elem().Set(reflect.ValueOf(s.values[i])) + } + return nil +} + +func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { + t.Run("request_type_ws_v2_overrides_legacy", func(t *testing.T) { + now := time.Now().UTC() + log, err := scanUsageLog(usageLogScannerStub{values: []any{ + int64(1), // id + int64(10), // user_id + int64(20), // api_key_id + int64(30), // account_id + sql.NullString{Valid: true, String: "req-1"}, + "gpt-5", // model + sql.NullInt64{}, // group_id + sql.NullInt64{}, // subscription_id + 1, // input_tokens + 2, // output_tokens + 3, // cache_creation_tokens + 4, // cache_read_tokens + 5, // cache_creation_5m_tokens + 6, // cache_creation_1h_tokens + 0.1, // input_cost + 0.2, // output_cost + 0.3, // cache_creation_cost + 0.4, // cache_read_cost + 1.0, // total_cost + 0.9, // actual_cost + 1.0, // rate_multiplier + sql.NullFloat64{}, // account_rate_multiplier + int16(service.BillingTypeBalance), + int16(service.RequestTypeWSV2), + false, // legacy stream + false, // legacy openai ws + sql.NullInt64{}, + sql.NullInt64{}, + sql.NullString{}, + sql.NullString{}, + 0, + sql.NullString{}, + sql.NullString{}, + sql.NullString{}, + false, + now, + }}) + require.NoError(t, err) + require.Equal(t, service.RequestTypeWSV2, log.RequestType) + require.True(t, log.Stream) + require.True(t, log.OpenAIWSMode) + }) + + t.Run("request_type_unknown_falls_back_to_legacy", func(t *testing.T) { + now := time.Now().UTC() + log, err := scanUsageLog(usageLogScannerStub{values: []any{ + int64(2), + int64(11), + int64(21), + int64(31), + sql.NullString{Valid: true, String: "req-2"}, + "gpt-5", + sql.NullInt64{}, + sql.NullInt64{}, + 1, 2, 3, 4, 5, 6, + 0.1, 0.2, 0.3, 0.4, 1.0, 0.9, + 1.0, + sql.NullFloat64{}, + int16(service.BillingTypeBalance), + 
int16(service.RequestTypeUnknown), + true, + false, + sql.NullInt64{}, + sql.NullInt64{}, + sql.NullString{}, + sql.NullString{}, + 0, + sql.NullString{}, + sql.NullString{}, + sql.NullString{}, + false, + now, + }}) + require.NoError(t, err) + require.Equal(t, service.RequestTypeStream, log.RequestType) + require.True(t, log.Stream) + require.False(t, log.OpenAIWSMode) + }) +} diff --git a/backend/internal/repository/usage_log_repo_unit_test.go b/backend/internal/repository/usage_log_repo_unit_test.go new file mode 100644 index 00000000..d0e14ffd --- /dev/null +++ b/backend/internal/repository/usage_log_repo_unit_test.go @@ -0,0 +1,41 @@ +//go:build unit + +package repository + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSafeDateFormat(t *testing.T) { + tests := []struct { + name string + granularity string + expected string + }{ + // 合法值 + {"hour", "hour", "YYYY-MM-DD HH24:00"}, + {"day", "day", "YYYY-MM-DD"}, + {"week", "week", "IYYY-IW"}, + {"month", "month", "YYYY-MM"}, + + // 非法值回退到默认 + {"空字符串", "", "YYYY-MM-DD"}, + {"未知粒度 year", "year", "YYYY-MM-DD"}, + {"未知粒度 minute", "minute", "YYYY-MM-DD"}, + + // 恶意字符串 + {"SQL 注入尝试", "'; DROP TABLE users; --", "YYYY-MM-DD"}, + {"带引号", "day'", "YYYY-MM-DD"}, + {"带括号", "day)", "YYYY-MM-DD"}, + {"Unicode", "日", "YYYY-MM-DD"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := safeDateFormat(tc.granularity) + require.Equal(t, tc.expected, got, "safeDateFormat(%q)", tc.granularity) + }) + } +} diff --git a/backend/internal/repository/user_group_rate_repo.go b/backend/internal/repository/user_group_rate_repo.go index eb65403b..e3b11096 100644 --- a/backend/internal/repository/user_group_rate_repo.go +++ b/backend/internal/repository/user_group_rate_repo.go @@ -6,6 +6,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/lib/pq" ) type userGroupRateRepository struct { @@ -41,6 +42,59 @@ func (r *userGroupRateRepository) GetByUserID(ctx 
context.Context, userID int64) return result, nil } +// GetByUserIDs 批量获取多个用户的专属分组倍率。 +// 返回结构:map[userID]map[groupID]rate +func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) { + result := make(map[int64]map[int64]float64, len(userIDs)) + if len(userIDs) == 0 { + return result, nil + } + + uniqueIDs := make([]int64, 0, len(userIDs)) + seen := make(map[int64]struct{}, len(userIDs)) + for _, userID := range userIDs { + if userID <= 0 { + continue + } + if _, exists := seen[userID]; exists { + continue + } + seen[userID] = struct{}{} + uniqueIDs = append(uniqueIDs, userID) + result[userID] = make(map[int64]float64) + } + if len(uniqueIDs) == 0 { + return result, nil + } + + rows, err := r.sql.QueryContext(ctx, ` + SELECT user_id, group_id, rate_multiplier + FROM user_group_rate_multipliers + WHERE user_id = ANY($1) + `, pq.Array(uniqueIDs)) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + for rows.Next() { + var userID int64 + var groupID int64 + var rate float64 + if err := rows.Scan(&userID, &groupID, &rate); err != nil { + return nil, err + } + if _, ok := result[userID]; !ok { + result[userID] = make(map[int64]float64) + } + result[userID][groupID] = rate + } + if err := rows.Err(); err != nil { + return nil, err + } + return result, nil +} + // GetByUserAndGroup 获取用户在特定分组的专属倍率 func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) { query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2` @@ -65,33 +119,43 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID // 分离需要删除和需要 upsert 的记录 var toDelete []int64 - toUpsert := make(map[int64]float64) + upsertGroupIDs := make([]int64, 0, len(rates)) + upsertRates := make([]float64, 0, len(rates)) for groupID, rate := range rates { if rate == nil { toDelete = append(toDelete, groupID) } 
else { - toUpsert[groupID] = *rate + upsertGroupIDs = append(upsertGroupIDs, groupID) + upsertRates = append(upsertRates, *rate) } } // 删除指定的记录 - for _, groupID := range toDelete { - _, err := r.sql.ExecContext(ctx, - `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`, - userID, groupID) - if err != nil { + if len(toDelete) > 0 { + if _, err := r.sql.ExecContext(ctx, + `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`, + userID, pq.Array(toDelete)); err != nil { return err } } // Upsert 记录 now := time.Now() - for groupID, rate := range toUpsert { + if len(upsertGroupIDs) > 0 { _, err := r.sql.ExecContext(ctx, ` INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at) - VALUES ($1, $2, $3, $4, $4) - ON CONFLICT (user_id, group_id) DO UPDATE SET rate_multiplier = $3, updated_at = $4 - `, userID, groupID, rate, now) + SELECT + $1::bigint, + data.group_id, + data.rate_multiplier, + $2::timestamptz, + $2::timestamptz + FROM unnest($3::bigint[], $4::double precision[]) AS data(group_id, rate_multiplier) + ON CONFLICT (user_id, group_id) + DO UPDATE SET + rate_multiplier = EXCLUDED.rate_multiplier, + updated_at = EXCLUDED.updated_at + `, userID, now, pq.Array(upsertGroupIDs), pq.Array(upsertRates)) if err != nil { return err } diff --git a/backend/internal/repository/user_msg_queue_cache.go b/backend/internal/repository/user_msg_queue_cache.go new file mode 100644 index 00000000..bb3ee698 --- /dev/null +++ b/backend/internal/repository/user_msg_queue_cache.go @@ -0,0 +1,186 @@ +package repository + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +// Redis Key 模式(使用 hash tag 确保 Redis Cluster 下同一 accountID 的 key 落入同一 slot) +// 格式: umq:{accountID}:lock / umq:{accountID}:last +const ( + umqKeyPrefix = "umq:" + umqLockSuffix = ":lock" // STRING (requestID), PX 
lockTtlMs + umqLastSuffix = ":last" // STRING (毫秒时间戳), EX 60s +) + +// Lua 脚本:原子获取串行锁(SET NX PX + 重入安全) +var acquireLockScript = redis.NewScript(` +local cur = redis.call('GET', KEYS[1]) +if cur == ARGV[1] then + redis.call('PEXPIRE', KEYS[1], tonumber(ARGV[2])) + return 1 +end +if cur ~= false then return 0 end +redis.call('SET', KEYS[1], ARGV[1], 'PX', tonumber(ARGV[2])) +return 1 +`) + +// Lua 脚本:原子释放锁 + 记录完成时间(使用 Redis TIME 避免时钟偏差) +var releaseLockScript = redis.NewScript(` +local cur = redis.call('GET', KEYS[1]) +if cur == ARGV[1] then + redis.call('DEL', KEYS[1]) + local t = redis.call('TIME') + local ms = tonumber(t[1])*1000 + math.floor(tonumber(t[2])/1000) + redis.call('SET', KEYS[2], ms, 'EX', 60) + return 1 +end +return 0 +`) + +// Lua 脚本:原子清理孤儿锁(仅在 PTTL == -1 时删除,避免 TOCTOU 竞态误删合法锁) +var forceReleaseLockScript = redis.NewScript(` +local pttl = redis.call('PTTL', KEYS[1]) +if pttl == -1 then + redis.call('DEL', KEYS[1]) + return 1 +end +return 0 +`) + +type userMsgQueueCache struct { + rdb *redis.Client +} + +// NewUserMsgQueueCache 创建用户消息队列缓存 +func NewUserMsgQueueCache(rdb *redis.Client) service.UserMsgQueueCache { + return &userMsgQueueCache{rdb: rdb} +} + +func umqLockKey(accountID int64) string { + // 格式: umq:{123}:lock — 花括号确保 Redis Cluster hash tag 生效 + return umqKeyPrefix + "{" + strconv.FormatInt(accountID, 10) + "}" + umqLockSuffix +} + +func umqLastKey(accountID int64) string { + // 格式: umq:{123}:last — 与 lockKey 同一 hash slot + return umqKeyPrefix + "{" + strconv.FormatInt(accountID, 10) + "}" + umqLastSuffix +} + +// umqScanPattern 用于 SCAN 扫描锁 key +func umqScanPattern() string { + return umqKeyPrefix + "{*}" + umqLockSuffix +} + +// AcquireLock 尝试获取账号级串行锁 +func (c *userMsgQueueCache) AcquireLock(ctx context.Context, accountID int64, requestID string, lockTtlMs int) (bool, error) { + key := umqLockKey(accountID) + result, err := acquireLockScript.Run(ctx, c.rdb, []string{key}, requestID, lockTtlMs).Int() + if err != nil { + return false, 
fmt.Errorf("umq acquire lock: %w", err) + } + return result == 1, nil +} + +// ReleaseLock 释放锁并记录完成时间 +func (c *userMsgQueueCache) ReleaseLock(ctx context.Context, accountID int64, requestID string) (bool, error) { + lockKey := umqLockKey(accountID) + lastKey := umqLastKey(accountID) + result, err := releaseLockScript.Run(ctx, c.rdb, []string{lockKey, lastKey}, requestID).Int() + if err != nil { + return false, fmt.Errorf("umq release lock: %w", err) + } + return result == 1, nil +} + +// GetLastCompletedMs 获取上次完成时间(毫秒时间戳) +func (c *userMsgQueueCache) GetLastCompletedMs(ctx context.Context, accountID int64) (int64, error) { + key := umqLastKey(accountID) + val, err := c.rdb.Get(ctx, key).Result() + if errors.Is(err, redis.Nil) { + return 0, nil + } + if err != nil { + return 0, fmt.Errorf("umq get last completed: %w", err) + } + ms, err := strconv.ParseInt(val, 10, 64) + if err != nil { + return 0, fmt.Errorf("umq parse last completed: %w", err) + } + return ms, nil +} + +// ForceReleaseLock 原子清理孤儿锁(仅在 PTTL == -1 时删除,防止 TOCTOU 竞态误删合法锁) +func (c *userMsgQueueCache) ForceReleaseLock(ctx context.Context, accountID int64) error { + key := umqLockKey(accountID) + _, err := forceReleaseLockScript.Run(ctx, c.rdb, []string{key}).Result() + if err != nil && !errors.Is(err, redis.Nil) { + return fmt.Errorf("umq force release lock: %w", err) + } + return nil +} + +// ScanLockKeys 扫描所有锁 key,仅返回 PTTL == -1(无过期时间)的孤儿锁 accountID 列表 +// 正常的锁都有 PX 过期时间,PTTL == -1 表示异常状态(如 Redis 故障恢复后丢失 TTL) +func (c *userMsgQueueCache) ScanLockKeys(ctx context.Context, maxCount int) ([]int64, error) { + var accountIDs []int64 + var cursor uint64 + pattern := umqScanPattern() + + for { + keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, 100).Result() + if err != nil { + return nil, fmt.Errorf("umq scan lock keys: %w", err) + } + for _, key := range keys { + // 检查 PTTL:只清理 PTTL == -1(无过期时间)的异常锁 + pttl, err := c.rdb.PTTL(ctx, key).Result() + if err != nil { + continue + } + // PTTL 返回值:-2 = 
key 不存在,-1 = 无过期时间,>0 = 剩余毫秒 + // go-redis 对哨兵值 -1/-2 不乘精度系数,直接返回 time.Duration(-1)/-2 + // 只删除 -1(无过期时间的异常锁),跳过正常持有的锁 + if pttl != time.Duration(-1) { + continue + } + + // 从 key 中提取 accountID: umq:{123}:lock → 提取 {} 内的数字 + openBrace := strings.IndexByte(key, '{') + closeBrace := strings.IndexByte(key, '}') + if openBrace < 0 || closeBrace <= openBrace+1 { + continue + } + idStr := key[openBrace+1 : closeBrace] + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + continue + } + accountIDs = append(accountIDs, id) + if len(accountIDs) >= maxCount { + return accountIDs, nil + } + } + cursor = nextCursor + if cursor == 0 { + break + } + } + return accountIDs, nil +} + +// GetCurrentTimeMs 通过 Redis TIME 命令获取当前服务器时间(毫秒),确保与锁记录的时间源一致 +func (c *userMsgQueueCache) GetCurrentTimeMs(ctx context.Context) (int64, error) { + t, err := c.rdb.Time(ctx).Result() + if err != nil { + return 0, fmt.Errorf("umq get redis time: %w", err) + } + return t.UnixMilli(), nil +} diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 17674291..b56aaaf9 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -61,6 +61,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error SetBalance(userIn.Balance). SetConcurrency(userIn.Concurrency). SetStatus(userIn.Status). + SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes). Save(ctx) if err != nil { return translatePersistenceError(err, nil, service.ErrEmailExists) @@ -143,6 +144,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error SetBalance(userIn.Balance). SetConcurrency(userIn.Concurrency). SetStatus(userIn.Status). + SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes). + SetSoraStorageUsedBytes(userIn.SoraStorageUsedBytes). 
// AddSoraStorageUsageWithQuota atomically adds deltaBytes to the user's Sora
// storage usage and, when a quota is set, enforces it in the same UPDATE so
// concurrent writers cannot oversubscribe.
//
// effectiveQuota == 0 means "unlimited". deltaBytes <= 0 is a no-op read that
// returns the current usage. Returns the new usage on success,
// service.ErrSoraStorageQuotaExceeded when the add would exceed the quota,
// or service.ErrUserNotFound when the user does not exist.
func (r *userRepository) AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error) {
	if deltaBytes <= 0 {
		user, err := r.GetByID(ctx, userID)
		if err != nil {
			return 0, err
		}
		return user.SoraStorageUsedBytes, nil
	}
	var newUsed int64
	// Single-statement compare-and-add: the WHERE clause both locates the row
	// and rejects the update when the quota would be exceeded, so the check
	// and the increment are one atomic operation.
	err := scanSingleRow(ctx, r.sql, `
		UPDATE users
		SET sora_storage_used_bytes = sora_storage_used_bytes + $2
		WHERE id = $1
		AND ($3 = 0 OR sora_storage_used_bytes + $2 <= $3)
		RETURNING sora_storage_used_bytes
	`, []any{userID, deltaBytes, effectiveQuota}, &newUsed)
	if err == nil {
		return newUsed, nil
	}
	if errors.Is(err, sql.ErrNoRows) {
		// No row updated: distinguish "user missing" from "quota exceeded".
		exists, existsErr := r.client.User.Query().Where(dbuser.IDEQ(userID)).Exist(ctx)
		if existsErr != nil {
			return 0, existsErr
		}
		if !exists {
			return 0, service.ErrUserNotFound
		}
		return 0, service.ErrSoraStorageQuotaExceeded
	}
	return 0, err
}

// ReleaseSoraStorageUsageAtomic atomically subtracts deltaBytes from the
// user's Sora storage usage, clamping at 0 via GREATEST so the counter can
// never go negative. deltaBytes <= 0 is a no-op read of the current usage.
func (r *userRepository) ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error) {
	if deltaBytes <= 0 {
		user, err := r.GetByID(ctx, userID)
		if err != nil {
			return 0, err
		}
		return user.SoraStorageUsedBytes, nil
	}
	var newUsed int64
	err := scanSingleRow(ctx, r.sql, `
		UPDATE users
		SET sora_storage_used_bytes = GREATEST(sora_storage_used_bytes - $2, 0)
		WHERE id = $1
		RETURNING sora_storage_used_bytes
	`, []any{userID, deltaBytes}, &newUsed)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return 0, service.ErrUserNotFound
		}
		return 0, err
	}
	return newUsed, nil
}

// ExistsByEmail reports whether a user with the given email exists.
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
	return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
}

// AddGroupToAllowedGroups inserts a (userID, groupID) row into the
// user_allowed_groups join table. The upsert (ON CONFLICT ... DO NOTHING)
// makes the call idempotent; clientFromContext lets it join an ambient
// transaction when one is present on ctx.
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
	client := clientFromContext(ctx, r.client)
	return client.UserAllowedGroup.Create().
		SetUserID(userID).
		SetGroupID(groupID).
		OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
		DoNothing().
		Exec(ctx)
}
// ProvideGitHubReleaseClient constructs the GitHub Release client.
// Proxy settings are read from config so servers in mainland China can reach
// GitHub through a proxy; AllowDirectOnError controls whether a failed proxy
// attempt may fall back to a direct connection.
func ProvideGitHubReleaseClient(cfg *config.Config) service.GitHubReleaseClient {
	return NewGitHubReleaseClient(cfg.Update.ProxyURL, cfg.Security.ProxyFallback.AllowDirectOnError)
}

// ProvidePricingRemoteClient constructs the remote pricing-data client.
// Same proxy/fallback wiring as ProvideGitHubReleaseClient, used to fetch
// pricing data hosted on GitHub.
func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
	return NewPricingRemoteClient(cfg.Update.ProxyURL, cfg.Security.ProxyFallback.AllowDirectOnError)
}
b/backend/internal/server/api_contract_test.go @@ -83,8 +83,18 @@ func TestAPIContracts(t *testing.T) { "status": "active", "ip_whitelist": null, "ip_blacklist": null, + "last_used_at": null, "quota": 0, "quota_used": 0, + "rate_limit_5h": 0, + "rate_limit_1d": 0, + "rate_limit_7d": 0, + "usage_5h": 0, + "usage_1d": 0, + "usage_7d": 0, + "window_5h_start": null, + "window_1d_start": null, + "window_7d_start": null, "expires_at": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" @@ -122,8 +132,18 @@ func TestAPIContracts(t *testing.T) { "status": "active", "ip_whitelist": null, "ip_blacklist": null, + "last_used_at": null, "quota": 0, "quota_used": 0, + "rate_limit_5h": 0, + "rate_limit_1d": 0, + "rate_limit_7d": 0, + "usage_5h": 0, + "usage_1d": 0, + "usage_7d": 0, + "window_5h_start": null, + "window_1d_start": null, + "window_7d_start": null, "expires_at": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" @@ -184,7 +204,12 @@ func TestAPIContracts(t *testing.T) { "image_price_1k": null, "image_price_2k": null, "image_price_4k": null, - "claude_code_only": false, + "sora_image_price_360": null, + "sora_image_price_540": null, + "sora_storage_quota_bytes": 0, + "sora_video_price_per_request": null, + "sora_video_price_per_request_hd": null, + "claude_code_only": false, "fallback_group_id": null, "fallback_group_id_on_invalid_request": null, "created_at": "2025-01-02T03:04:05Z", @@ -378,10 +403,12 @@ func TestAPIContracts(t *testing.T) { "user_id": 1, "api_key_id": 100, "account_id": 200, - "request_id": "req_123", - "model": "claude-3", - "group_id": null, - "subscription_id": null, + "request_id": "req_123", + "model": "claude-3", + "request_type": "stream", + "openai_ws_mode": false, + "group_id": null, + "subscription_id": null, "input_tokens": 10, "output_tokens": 20, "cache_creation_tokens": 1, @@ -401,6 +428,8 @@ func TestAPIContracts(t *testing.T) { "first_token_ms": 50, "image_count": 0, 
"image_size": null, + "media_type": null, + "cache_ttl_overridden": false, "created_at": "2025-01-02T03:04:05Z", "user_agent": null } @@ -417,9 +446,10 @@ func TestAPIContracts(t *testing.T) { setup: func(t *testing.T, deps *contractDeps) { t.Helper() deps.settingRepo.SetAll(map[string]string{ - service.SettingKeyRegistrationEnabled: "true", - service.SettingKeyEmailVerifyEnabled: "false", - service.SettingKeyPromoCodeEnabled: "true", + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyEmailVerifyEnabled: "false", + service.SettingKeyRegistrationEmailSuffixWhitelist: "[]", + service.SettingKeyPromoCodeEnabled: "true", service.SettingKeySMTPHost: "smtp.example.com", service.SettingKeySMTPPort: "587", @@ -458,6 +488,7 @@ func TestAPIContracts(t *testing.T) { "data": { "registration_enabled": true, "email_verify_enabled": false, + "registration_email_suffix_whitelist": [], "promo_code_enabled": true, "password_reset_enabled": false, "totp_enabled": false, @@ -488,18 +519,23 @@ func TestAPIContracts(t *testing.T) { "doc_url": "https://docs.example.com", "default_concurrency": 5, "default_balance": 1.25, + "default_subscriptions": [], "enable_model_fallback": false, "fallback_model_anthropic": "claude-3-5-sonnet-20241022", "fallback_model_antigravity": "gemini-2.5-pro", "fallback_model_gemini": "gemini-2.5-pro", - "fallback_model_openai": "gpt-4o", - "enable_identity_patch": true, - "identity_patch_prompt": "", - "invitation_code_enabled": false, - "home_content": "", + "fallback_model_openai": "gpt-4o", + "enable_identity_patch": true, + "identity_patch_prompt": "", + "sora_client_enabled": false, + "invitation_code_enabled": false, + "home_content": "", "hide_ccs_import_button": false, "purchase_subscription_enabled": false, - "purchase_subscription_url": "" + "purchase_subscription_url": "", + "min_claude_code_version": "", + "allow_ungrouped_key_scheduling": false, + "custom_menu_items": [] } }`, }, @@ -592,13 +628,13 @@ func newContractDeps(t 
*testing.T) *contractDeps { RunMode: config.RunModeStandard, } - userService := service.NewUserService(userRepo, nil) + userService := service.NewUserService(userRepo, nil, nil) apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg) usageRepo := newStubUsageLogRepo() usageService := service.NewUsageService(usageRepo, userRepo, nil, nil) - subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil) + subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil, nil, cfg) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil) @@ -607,12 +643,12 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil) + adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) - adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil) - adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil) + adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) jwtAuth := func(c *gin.Context) { c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ @@ -767,6 +803,10 
@@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID return 0, errors.New("not implemented") } +func (r *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error { + return errors.New("not implemented") +} + func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { return errors.New("not implemented") } @@ -924,6 +964,10 @@ func (s *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID st return nil, errors.New("not implemented") } +func (s *stubAccountRepo) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + func (s *stubAccountRepo) Update(ctx context.Context, account *service.Account) error { return errors.New("not implemented") } @@ -1004,6 +1048,14 @@ func (s *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Conte return nil, errors.New("not implemented") } +func (s *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return errors.New("not implemented") } @@ -1361,7 +1413,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error { return nil } -func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { +func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) { ids := make([]int64, 
0, len(r.byID)) for id := range r.byID { if r.byID[id].UserID == userID { @@ -1461,6 +1513,30 @@ func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amoun return 0, errors.New("not implemented") } +func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { + key, ok := r.byID[id] + if !ok { + return service.ErrAPIKeyNotFound + } + ts := usedAt + key.LastUsedAt = &ts + key.UpdatedAt = usedAt + clone := *key + r.byID[id] = &clone + r.byKey[clone.Key] = &clone + return nil +} + +func (r *stubApiKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error { + return nil +} +func (r *stubApiKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error { + return nil +} +func (r *stubApiKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) { + return nil, nil +} + type stubUsageLogRepo struct { userLogs map[int64][]service.UsageLog } @@ -1529,11 +1605,15 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { +func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { +func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime 
time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) { return nil, errors.New("not implemented") } @@ -1606,11 +1686,11 @@ func (r *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID i return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) { +func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) { return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { +func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { return nil, errors.New("not implemented") } diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index d2d8ed40..a8034e98 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -51,6 +51,9 @@ func ProvideRouter( if err := r.SetTrustedProxies(nil); err != nil { log.Printf("Failed to disable trusted proxies: %v", err) } + if cfg.Server.Mode == "release" { + log.Printf("Warning: server.trusted_proxies is empty in release mode; client IP trust chain is disabled") + } } return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient) diff --git 
//go:build unit

package middleware

import (
	"context"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/Wei-Shaw/sub2api/internal/config"
	"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
	"github.com/Wei-Shaw/sub2api/internal/service"
	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/require"
)

// TestAdminAuthJWTValidatesTokenVersion verifies that the admin auth
// middleware enforces the JWT TokenVersion claim: tokens minted before a
// password change (stale version) are rejected with TOKEN_REVOKED, while
// tokens carrying the current version are allowed through. Both the
// Authorization header path and the WebSocket Sec-WebSocket-Protocol path
// are exercised.
func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
	gin.SetMode(gin.TestMode)

	cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
	authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)

	// Current state of the admin account; TokenVersion 2 means any token
	// carrying version 1 must be treated as revoked.
	admin := &service.User{
		ID:           1,
		Email:        "admin@example.com",
		Role:         service.RoleAdmin,
		Status:       service.StatusActive,
		TokenVersion: 2,
		Concurrency:  1,
	}

	userRepo := &stubUserRepo{
		getByID: func(ctx context.Context, id int64) (*service.User, error) {
			if id != admin.ID {
				return nil, service.ErrUserNotFound
			}
			// Return a copy so the middleware cannot mutate the fixture.
			clone := *admin
			return &clone, nil
		},
	}
	userService := service.NewUserService(userRepo, nil, nil)

	router := gin.New()
	router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil)))
	router.GET("/t", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"ok": true})
	})

	t.Run("token_version_mismatch_rejected", func(t *testing.T) {
		// Token carries the previous version → must be rejected as revoked.
		token, err := authService.GenerateToken(&service.User{
			ID:           admin.ID,
			Email:        admin.Email,
			Role:         admin.Role,
			TokenVersion: admin.TokenVersion - 1,
		})
		require.NoError(t, err)

		w := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "/t", nil)
		req.Header.Set("Authorization", "Bearer "+token)
		router.ServeHTTP(w, req)

		require.Equal(t, http.StatusUnauthorized, w.Code)
		require.Contains(t, w.Body.String(), "TOKEN_REVOKED")
	})

	t.Run("token_version_match_allows", func(t *testing.T) {
		token, err := authService.GenerateToken(&service.User{
			ID:           admin.ID,
			Email:        admin.Email,
			Role:         admin.Role,
			TokenVersion: admin.TokenVersion,
		})
		require.NoError(t, err)

		w := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "/t", nil)
		req.Header.Set("Authorization", "Bearer "+token)
		router.ServeHTTP(w, req)

		require.Equal(t, http.StatusOK, w.Code)
	})

	t.Run("websocket_token_version_mismatch_rejected", func(t *testing.T) {
		token, err := authService.GenerateToken(&service.User{
			ID:           admin.ID,
			Email:        admin.Email,
			Role:         admin.Role,
			TokenVersion: admin.TokenVersion - 1,
		})
		require.NoError(t, err)

		w := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "/t", nil)
		// WebSocket upgrade path: the JWT travels in the subprotocol list.
		req.Header.Set("Upgrade", "websocket")
		req.Header.Set("Connection", "Upgrade")
		req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token)
		router.ServeHTTP(w, req)

		require.Equal(t, http.StatusUnauthorized, w.Code)
		require.Contains(t, w.Body.String(), "TOKEN_REVOKED")
	})

	t.Run("websocket_token_version_match_allows", func(t *testing.T) {
		token, err := authService.GenerateToken(&service.User{
			ID:           admin.ID,
			Email:        admin.Email,
			Role:         admin.Role,
			TokenVersion: admin.TokenVersion,
		})
		require.NoError(t, err)

		w := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "/t", nil)
		req.Header.Set("Upgrade", "websocket")
		req.Header.Set("Connection", "Upgrade")
		req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token)
		router.ServeHTTP(w, req)

		require.Equal(t, http.StatusOK, w.Code)
	})
}

// stubUserRepo is a minimal user-repository stub for this test: only GetByID
// is expected to be exercised; every other method panics so an unexpected
// repository call fails the test loudly instead of silently succeeding.
type stubUserRepo struct {
	// getByID backs GetByID; must be assigned by the test.
	getByID func(ctx context.Context, id int64) (*service.User, error)
}

func (s *stubUserRepo) Create(ctx context.Context, user *service.User) error {
	panic("unexpected Create call")
}

func (s *stubUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) {
	if s.getByID == nil {
		panic("GetByID not stubbed")
	}
	return s.getByID(ctx, id)
}

func (s *stubUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) {
	panic("unexpected GetByEmail call")
}

func (s *stubUserRepo) GetFirstAdmin(ctx context.Context) (*service.User, error) {
	panic("unexpected GetFirstAdmin call")
}

func (s *stubUserRepo) Update(ctx context.Context, user *service.User) error {
	panic("unexpected Update call")
}

func (s *stubUserRepo) Delete(ctx context.Context, id int64) error {
	panic("unexpected Delete call")
}

func (s *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
	panic("unexpected List call")
}

func (s *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
	panic("unexpected ListWithFilters call")
}

func (s *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
	panic("unexpected UpdateBalance call")
}

func (s *stubUserRepo) DeductBalance(ctx context.Context, id int64, amount float64) error {
	panic("unexpected DeductBalance call")
}

func (s *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
	panic("unexpected UpdateConcurrency call")
}

func (s *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
	panic("unexpected ExistsByEmail call")
}

func (s *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
	panic("unexpected RemoveGroupFromAllowedGroups call")
}

func (s *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
	panic("unexpected AddGroupToAllowedGroups call")
}

func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
	panic("unexpected UpdateTotpSecret call")
}

func (s *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error {
	panic("unexpected EnableTotp call")
}

func (s *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error {
	panic("unexpected DisableTotp call")
}
API Key认证中间件(支持订阅验证) +// +// 中间件职责分为两层: +// - 鉴权(Authentication):验证 Key 有效性、用户状态、IP 限制 —— 始终执行 +// - 计费执行(Billing Enforcement):过期/配额/订阅/余额检查 —— skipBilling 时整块跳过 +// +// /v1/usage 端点只需鉴权,不需要计费执行(允许过期/配额耗尽的 Key 查询自身用量)。 func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc { return func(c *gin.Context) { + // ── 1. 提取 API Key ────────────────────────────────────────── + queryKey := strings.TrimSpace(c.Query("key")) queryApiKey := strings.TrimSpace(c.Query("api_key")) if queryKey != "" || queryApiKey != "" { @@ -36,8 +43,8 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti if authHeader != "" { // 验证Bearer scheme parts := strings.SplitN(authHeader, " ", 2) - if len(parts) == 2 && parts[0] == "Bearer" { - apiKeyString = parts[1] + if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") { + apiKeyString = strings.TrimSpace(parts[1]) } } @@ -57,7 +64,8 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti return } - // 从数据库验证API key + // ── 2. 验证 Key 存在 ───────────────────────────────────────── + apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString) if err != nil { if errors.Is(err, service.ErrAPIKeyNotFound) { @@ -68,37 +76,21 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti return } - // 检查API key是否激活 - if !apiKey.IsActive() { - // Provide more specific error message based on status - switch apiKey.Status { - case service.StatusAPIKeyQuotaExhausted: - AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完") - case service.StatusAPIKeyExpired: - AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期") - default: - AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled") - } - return - } + // ── 3. 
基础鉴权(始终执行) ───────────────────────────────── - // 检查API Key是否过期(即使状态是active,也要检查时间) - if apiKey.IsExpired() { - AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期") - return - } - - // 检查API Key配额是否耗尽 - if apiKey.IsQuotaExhausted() { - AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完") + // disabled / 未知状态 → 无条件拦截(expired 和 quota_exhausted 留给计费阶段) + if !apiKey.IsActive() && + apiKey.Status != service.StatusAPIKeyExpired && + apiKey.Status != service.StatusAPIKeyQuotaExhausted { + AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled") return } // 检查 IP 限制(白名单/黑名单) // 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制 if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 { - clientIP := ip.GetClientIP(c) - allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist) + clientIP := ip.GetTrustedClientIP(c) + allowed, _ := ip.CheckIPRestrictionWithCompiledRules(clientIP, apiKey.CompiledIPWhitelist, apiKey.CompiledIPBlacklist) if !allowed { AbortWithError(c, 403, "ACCESS_DENIED", "Access denied") return @@ -117,8 +109,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti return } + // ── 4. SimpleMode → early return ───────────────────────────── + if cfg.RunMode == config.RunModeSimple { - // 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文 c.Set(string(ContextKeyAPIKey), apiKey) c.Set(string(ContextKeyUser), AuthSubject{ UserID: apiKey.User.ID, @@ -126,58 +119,94 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti }) c.Set(string(ContextKeyUserRole), apiKey.User.Role) setGroupContext(c, apiKey.Group) + _ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID) c.Next() return } - // 判断计费方式:订阅模式 vs 余额模式 + // ── 5. 
加载订阅(订阅模式时始终加载) ─────────────────────── + + // skipBilling: /v1/usage 只需鉴权,跳过所有计费执行 + skipBilling := c.Request.URL.Path == "/v1/usage" + + var subscription *service.UserSubscription isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType() if isSubscriptionType && subscriptionService != nil { - // 订阅模式:验证订阅 - subscription, err := subscriptionService.GetActiveSubscription( + sub, subErr := subscriptionService.GetActiveSubscription( c.Request.Context(), apiKey.User.ID, apiKey.Group.ID, ) - if err != nil { - AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group") - return - } - - // 验证订阅状态(是否过期、暂停等) - if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil { - AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error()) - return - } - - // 激活滑动窗口(首次使用时) - if err := subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription); err != nil { - log.Printf("Failed to activate subscription windows: %v", err) - } - - // 检查并重置过期窗口 - if err := subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription); err != nil { - log.Printf("Failed to reset subscription windows: %v", err) - } - - // 预检查用量限制(使用0作为额外费用进行预检查) - if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil { - AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error()) - return - } - - // 将订阅信息存入上下文 - c.Set(string(ContextKeySubscription), subscription) - } else { - // 余额模式:检查用户余额 - if apiKey.User.Balance <= 0 { - AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance") - return + if subErr != nil { + if !skipBilling { + AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group") + return + } + // skipBilling: 订阅不存在也放行,handler 会返回可用的数据 + } else { + subscription = sub } } - // 将API key和用户信息存入上下文 + // ── 6. 
计费执行(skipBilling 时整块跳过) ──────────────────── + + if !skipBilling { + // Key 状态检查 + switch apiKey.Status { + case service.StatusAPIKeyQuotaExhausted: + AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完") + return + case service.StatusAPIKeyExpired: + AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期") + return + } + + // 运行时过期/配额检查(即使状态是 active,也要检查时间和用量) + if apiKey.IsExpired() { + AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期") + return + } + if apiKey.IsQuotaExhausted() { + AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完") + return + } + + // 订阅模式:验证订阅限额 + if subscription != nil { + needsMaintenance, validateErr := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group) + if validateErr != nil { + code := "SUBSCRIPTION_INVALID" + status := 403 + if errors.Is(validateErr, service.ErrDailyLimitExceeded) || + errors.Is(validateErr, service.ErrWeeklyLimitExceeded) || + errors.Is(validateErr, service.ErrMonthlyLimitExceeded) { + code = "USAGE_LIMIT_EXCEEDED" + status = 429 + } + AbortWithError(c, status, code, validateErr.Error()) + return + } + + // 窗口维护异步化(不阻塞请求) + if needsMaintenance { + maintenanceCopy := *subscription + subscriptionService.DoWindowMaintenance(&maintenanceCopy) + } + } else { + // 非订阅模式 或 订阅模式但 subscriptionService 未注入:回退到余额检查 + if apiKey.User.Balance <= 0 { + AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance") + return + } + } + } + + // ── 7. 
设置上下文 → Next ───────────────────────────────────── + + if subscription != nil { + c.Set(string(ContextKeySubscription), subscription) + } c.Set(string(ContextKeyAPIKey), apiKey) c.Set(string(ContextKeyUser), AuthSubject{ UserID: apiKey.User.ID, @@ -185,6 +214,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti }) c.Set(string(ContextKeyUserRole), apiKey.User.Role) setGroupContext(c, apiKey.Group) + _ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID) c.Next() } diff --git a/backend/internal/server/middleware/api_key_auth_google.go b/backend/internal/server/middleware/api_key_auth_google.go index 38fbe38b..84d93edc 100644 --- a/backend/internal/server/middleware/api_key_auth_google.go +++ b/backend/internal/server/middleware/api_key_auth_google.go @@ -64,6 +64,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs }) c.Set(string(ContextKeyUserRole), apiKey.User.Role) setGroupContext(c, apiKey.Group) + _ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID) c.Next() return } @@ -79,17 +80,25 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs abortWithGoogleError(c, 403, "No active subscription found for this group") return } - if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil { - abortWithGoogleError(c, 403, err.Error()) - return - } - _ = subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription) - _ = subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription) - if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil { - abortWithGoogleError(c, 429, err.Error()) + + needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group) + if err != nil { + status := 403 + if errors.Is(err, service.ErrDailyLimitExceeded) || + errors.Is(err, service.ErrWeeklyLimitExceeded) || + errors.Is(err, 
service.ErrMonthlyLimitExceeded) { + status = 429 + } + abortWithGoogleError(c, status, err.Error()) return } + c.Set(string(ContextKeySubscription), subscription) + + if needsMaintenance { + maintenanceCopy := *subscription + subscriptionService.DoWindowMaintenance(&maintenanceCopy) + } } else { if apiKey.User.Balance <= 0 { abortWithGoogleError(c, 403, "Insufficient account balance") @@ -104,6 +113,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs }) c.Set(string(ContextKeyUserRole), apiKey.User.Role) setGroupContext(c, apiKey.Group) + _ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID) c.Next() } } diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go index 38b93cb2..49db5f19 100644 --- a/backend/internal/server/middleware/api_key_auth_google_test.go +++ b/backend/internal/server/middleware/api_key_auth_google_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" @@ -18,7 +19,17 @@ import ( ) type fakeAPIKeyRepo struct { - getByKey func(ctx context.Context, key string) (*service.APIKey, error) + getByKey func(ctx context.Context, key string) (*service.APIKey, error) + updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error +} + +type fakeGoogleSubscriptionRepo struct { + getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) + updateStatus func(ctx context.Context, subscriptionID int64, status string) error + activateWindow func(ctx context.Context, id int64, start time.Time) error + resetDaily func(ctx context.Context, id int64, start time.Time) error + resetWeekly func(ctx context.Context, id int64, start time.Time) error + resetMonthly func(ctx context.Context, id int64, start time.Time) error } func (f fakeAPIKeyRepo) Create(ctx context.Context, 
key *service.APIKey) error { @@ -45,7 +56,7 @@ func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error { func (f fakeAPIKeyRepo) Delete(ctx context.Context, id int64) error { return errors.New("not implemented") } -func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { +func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } func (f fakeAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { @@ -78,6 +89,100 @@ func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([ func (f fakeAPIKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { return 0, errors.New("not implemented") } +func (f fakeAPIKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { + if f.updateLastUsed != nil { + return f.updateLastUsed(ctx, id, usedAt) + } + return nil +} +func (f fakeAPIKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error { + return nil +} +func (f fakeAPIKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error { + return nil +} +func (f fakeAPIKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) { + return &service.APIKeyRateLimitData{}, nil +} + +func (f fakeGoogleSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error { + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) { + return nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) 
(*service.UserSubscription, error) { + return nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { + if f.getActive != nil { + return f.getActive(ctx, userID, groupID) + } + return nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) Update(ctx context.Context, sub *service.UserSubscription) error { + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) Delete(ctx context.Context, id int64) error { + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { + return nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { + return nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { + return false, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error { + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error { + if f.updateStatus != nil { + return f.updateStatus(ctx, 
subscriptionID, status) + } + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error { + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ActivateWindows(ctx context.Context, id int64, start time.Time) error { + if f.activateWindow != nil { + return f.activateWindow(ctx, id, start) + } + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ResetDailyUsage(ctx context.Context, id int64, start time.Time) error { + if f.resetDaily != nil { + return f.resetDaily(ctx, id, start) + } + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ResetWeeklyUsage(ctx context.Context, id int64, start time.Time) error { + if f.resetWeekly != nil { + return f.resetWeekly(ctx, id, start) + } + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ResetMonthlyUsage(ctx context.Context, id int64, start time.Time) error { + if f.resetMonthly != nil { + return f.resetMonthly(ctx, id, start) + } + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { + return 0, errors.New("not implemented") +} type googleErrorResponse struct { Error struct { @@ -356,3 +461,226 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) { require.Equal(t, "Insufficient account balance", resp.Error.Message) require.Equal(t, "PERMISSION_DENIED", resp.Error.Status) } + +func TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedOnSuccess(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 11, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 201, + UserID: user.ID, 
+ Key: "google-touch-ok", + Status: service.StatusActive, + User: user, + } + + var touchedID int64 + var touchedAt time.Time + r := gin.New() + apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + touchedID = id + touchedAt = usedAt + return nil + }, + }) + cfg := &config.Config{RunMode: config.RunModeSimple} + r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)) + r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) + + req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) + req.Header.Set("x-goog-api-key", apiKey.Key) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, apiKey.ID, touchedID) + require.False(t, touchedAt.IsZero()) +} + +func TestApiKeyAuthWithSubscriptionGoogle_TouchFailureDoesNotBlock(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 12, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 202, + UserID: user.ID, + Key: "google-touch-fail", + Status: service.StatusActive, + User: user, + } + + touchCalls := 0 + r := gin.New() + apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + touchCalls++ + return errors.New("write failed") + }, + }) + cfg := &config.Config{RunMode: config.RunModeSimple} + r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)) + r.GET("/v1beta/test", func(c 
*gin.Context) { c.JSON(200, gin.H{"ok": true}) }) + + req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) + req.Header.Set("x-goog-api-key", apiKey.Key) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, 1, touchCalls) +} + +func TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedInStandardMode(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 13, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 203, + UserID: user.ID, + Key: "google-touch-standard", + Status: service.StatusActive, + User: user, + } + + touchCalls := 0 + r := gin.New() + apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + touchCalls++ + return nil + }, + }) + cfg := &config.Config{RunMode: config.RunModeStandard} + r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)) + r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) + + req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) + req.Header.Set("Authorization", "Bearer "+apiKey.Key) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, 1, touchCalls) +} + +func TestApiKeyAuthWithSubscriptionGoogle_SubscriptionLimitExceededReturns429(t *testing.T) { + gin.SetMode(gin.TestMode) + + limit := 1.0 + group := &service.Group{ + ID: 77, + Name: "gemini-sub", + Status: service.StatusActive, + Platform: service.PlatformGemini, + Hydrated: true, + SubscriptionType: service.SubscriptionTypeSubscription, + DailyLimitUSD: &limit, + } + user := &service.User{ + ID: 999, + Role: 
service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 501, + UserID: user.ID, + Key: "google-sub-limit", + Status: service.StatusActive, + User: user, + Group: group, + } + apiKey.GroupID = &group.ID + + apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + }) + + now := time.Now() + sub := &service.UserSubscription{ + ID: 601, + UserID: user.ID, + GroupID: group.ID, + Status: service.SubscriptionStatusActive, + ExpiresAt: now.Add(24 * time.Hour), + DailyWindowStart: &now, + DailyUsageUSD: 10, + } + subscriptionService := service.NewSubscriptionService(nil, fakeGoogleSubscriptionRepo{ + getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { + if userID != user.ID || groupID != group.ID { + return nil, service.ErrSubscriptionNotFound + } + clone := *sub + return &clone, nil + }, + updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil }, + activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetDaily: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil }, + }, nil, nil, &config.Config{RunMode: config.RunModeStandard}) + + r := gin.New() + r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, &config.Config{RunMode: config.RunModeStandard})) + r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) + + req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) + req.Header.Set("x-goog-api-key", apiKey.Key) + rec := httptest.NewRecorder() + 
r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusTooManyRequests, rec.Code) + var resp googleErrorResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, http.StatusTooManyRequests, resp.Error.Code) + require.Equal(t, "RESOURCE_EXHAUSTED", resp.Error.Status) + require.Contains(t, resp.Error.Message, "daily usage limit exceeded") +} diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index 9d514818..22befa2a 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -57,10 +57,41 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { }, } - t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) { - cfg := &config.Config{RunMode: config.RunModeSimple} + t.Run("standard_mode_needs_maintenance_does_not_block_request", func(t *testing.T) { + cfg := &config.Config{RunMode: config.RunModeStandard} + cfg.SubscriptionMaintenance.WorkerCount = 1 + cfg.SubscriptionMaintenance.QueueSize = 1 + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) - subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil) + + past := time.Now().Add(-48 * time.Hour) + sub := &service.UserSubscription{ + ID: 55, + UserID: user.ID, + GroupID: group.ID, + Status: service.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + DailyWindowStart: &past, + DailyUsageUSD: 0, + } + maintenanceCalled := make(chan struct{}, 1) + subscriptionRepo := &stubUserSubscriptionRepo{ + getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { + clone := *sub + return &clone, nil + }, + updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil }, + activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetDaily: func(ctx 
context.Context, id int64, start time.Time) error { + maintenanceCalled <- struct{}{} + return nil + }, + resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil }, + } + subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, nil, cfg) + t.Cleanup(subscriptionService.Stop) + router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) w := httptest.NewRecorder() @@ -68,6 +99,40 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { req.Header.Set("x-api-key", apiKey.Key) router.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) + select { + case <-maintenanceCalled: + // ok + case <-time.After(time.Second): + t.Fatalf("expected maintenance to be scheduled") + } + }) + + t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) { + cfg := &config.Config{RunMode: config.RunModeSimple} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, nil, cfg) + router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("x-api-key", apiKey.Key) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("simple_mode_accepts_lowercase_bearer", func(t *testing.T) { + cfg := &config.Config{RunMode: config.RunModeSimple} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, nil, cfg) + router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("Authorization", "bearer "+apiKey.Key) + router.ServeHTTP(w, req) + 
require.Equal(t, http.StatusOK, w.Code) }) @@ -99,7 +164,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil }, resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil }, } - subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil) + subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, nil, cfg) router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) w := httptest.NewRecorder() @@ -235,6 +300,198 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) { require.Equal(t, http.StatusOK, w.Code) } +func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 7, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 100, + UserID: user.ID, + Key: "test-key", + Status: service.StatusActive, + User: user, + IPWhitelist: []string{"1.2.3.4"}, + } + + apiKeyRepo := &stubApiKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + } + + cfg := &config.Config{RunMode: config.RunModeSimple} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + router := gin.New() + require.NoError(t, router.SetTrustedProxies(nil)) + router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) + router.GET("/t", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.RemoteAddr = "9.9.9.9:12345" + req.Header.Set("x-api-key", apiKey.Key) + req.Header.Set("X-Forwarded-For", "1.2.3.4") + req.Header.Set("X-Real-IP", "1.2.3.4") + 
req.Header.Set("CF-Connecting-IP", "1.2.3.4") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusForbidden, w.Code) + require.Contains(t, w.Body.String(), "ACCESS_DENIED") +} + +func TestAPIKeyAuthTouchesLastUsedOnSuccess(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 7, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 100, + UserID: user.ID, + Key: "touch-ok", + Status: service.StatusActive, + User: user, + } + + var touchedID int64 + var touchedAt time.Time + apiKeyRepo := &stubApiKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + touchedID = id + touchedAt = usedAt + return nil + }, + } + + cfg := &config.Config{RunMode: config.RunModeSimple} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + router := newAuthTestRouter(apiKeyService, nil, cfg) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("x-api-key", apiKey.Key) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, apiKey.ID, touchedID) + require.False(t, touchedAt.IsZero(), "expected touch timestamp") +} + +func TestAPIKeyAuthTouchLastUsedFailureDoesNotBlock(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 8, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 101, + UserID: user.ID, + Key: "touch-fail", + Status: service.StatusActive, + User: user, + } + + touchCalls := 0 + apiKeyRepo := &stubApiKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, 
service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + touchCalls++ + return errors.New("db unavailable") + }, + } + + cfg := &config.Config{RunMode: config.RunModeSimple} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + router := newAuthTestRouter(apiKeyService, nil, cfg) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("x-api-key", apiKey.Key) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, "touch failure should not block request") + require.Equal(t, 1, touchCalls) +} + +func TestAPIKeyAuthTouchesLastUsedInStandardMode(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 9, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 102, + UserID: user.ID, + Key: "touch-standard", + Status: service.StatusActive, + User: user, + } + + touchCalls := 0 + apiKeyRepo := &stubApiKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + touchCalls++ + return nil + }, + } + + cfg := &config.Config{RunMode: config.RunModeStandard} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + router := newAuthTestRouter(apiKeyService, nil, cfg) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("x-api-key", apiKey.Key) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, 1, touchCalls) +} + func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine 
{ router := gin.New() router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg))) @@ -245,7 +502,8 @@ func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService } type stubApiKeyRepo struct { - getByKey func(ctx context.Context, key string) (*service.APIKey, error) + getByKey func(ctx context.Context, key string) (*service.APIKey, error) + updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error } func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.APIKey) error { @@ -279,7 +537,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error { return errors.New("not implemented") } -func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { +func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } @@ -323,6 +581,23 @@ func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amoun return 0, errors.New("not implemented") } +func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { + if r.updateLastUsed != nil { + return r.updateLastUsed(ctx, id, usedAt) + } + return nil +} + +func (r *stubApiKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error { + return nil +} +func (r *stubApiKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error { + return nil +} +func (r *stubApiKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) { + return nil, nil +} + type stubUserSubscriptionRepo struct { getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) updateStatus func(ctx context.Context, subscriptionID int64, status 
string) error diff --git a/backend/internal/server/middleware/client_request_id.go b/backend/internal/server/middleware/client_request_id.go index d22b6cc5..6838d6af 100644 --- a/backend/internal/server/middleware/client_request_id.go +++ b/backend/internal/server/middleware/client_request_id.go @@ -2,10 +2,13 @@ package middleware import ( "context" + "strings" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/gin-gonic/gin" "github.com/google/uuid" + "go.uber.org/zap" ) // ClientRequestID ensures every request has a unique client_request_id in request.Context(). @@ -24,7 +27,10 @@ func ClientRequestID() gin.HandlerFunc { } id := uuid.New().String() - c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ClientRequestID, id)) + ctx := context.WithValue(c.Request.Context(), ctxkey.ClientRequestID, id) + requestLogger := logger.FromContext(ctx).With(zap.String("client_request_id", strings.TrimSpace(id))) + ctx = logger.IntoContext(ctx, requestLogger) + c.Request = c.Request.WithContext(ctx) c.Next() } } diff --git a/backend/internal/server/middleware/cors.go b/backend/internal/server/middleware/cors.go index 7d82f183..03d5d025 100644 --- a/backend/internal/server/middleware/cors.go +++ b/backend/internal/server/middleware/cors.go @@ -50,6 +50,19 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc { } allowedSet[origin] = struct{}{} } + allowHeaders := []string{ + "Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", + "accept", "origin", "Cache-Control", "X-Requested-With", "X-API-Key", + } + // OpenAI Node SDK 会发送 x-stainless-* 请求头,需在 CORS 中显式放行。 + openAIProperties := []string{ + "lang", "package-version", "os", "arch", "retry-count", "runtime", + "runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval", "timeout", + } + for _, prop := range openAIProperties { + allowHeaders = append(allowHeaders, "x-stainless-"+prop) + 
} + allowHeadersValue := strings.Join(allowHeaders, ", ") return func(c *gin.Context) { origin := strings.TrimSpace(c.GetHeader("Origin")) @@ -68,11 +81,11 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc { if allowCredentials { c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") } + c.Writer.Header().Set("Access-Control-Allow-Headers", allowHeadersValue) + c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH") + c.Writer.Header().Set("Access-Control-Expose-Headers", "ETag") + c.Writer.Header().Set("Access-Control-Max-Age", "86400") } - - c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key") - c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH") - // 处理预检请求 if c.Request.Method == http.MethodOptions { if originAllowed { diff --git a/backend/internal/server/middleware/cors_test.go b/backend/internal/server/middleware/cors_test.go new file mode 100644 index 00000000..6d0bea36 --- /dev/null +++ b/backend/internal/server/middleware/cors_test.go @@ -0,0 +1,308 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func init() { + // cors_test 与 security_headers_test 在同一个包,但 init 是幂等的 + gin.SetMode(gin.TestMode) +} + +// --- Task 8.2: 验证 CORS 条件化头部 --- + +func TestCORS_DisallowedOrigin_NoAllowHeaders(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://allowed.example.com"}, + AllowCredentials: false, + } + middleware := CORS(cfg) + + tests := []struct { + name string + method string + origin string + }{ + { + name: "preflight_disallowed_origin", + method: http.MethodOptions, + origin: "https://evil.example.com", + }, + { + name: "get_disallowed_origin", + 
method: http.MethodGet, + origin: "https://evil.example.com", + }, + { + name: "post_disallowed_origin", + method: http.MethodPost, + origin: "https://attacker.example.com", + }, + { + name: "preflight_no_origin", + method: http.MethodOptions, + origin: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(tt.method, "/", nil) + if tt.origin != "" { + c.Request.Header.Set("Origin", tt.origin) + } + + middleware(c) + + // 不应设置 Allow-Headers、Allow-Methods 和 Max-Age + assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"), + "不允许的 origin 不应收到 Allow-Headers") + assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods"), + "不允许的 origin 不应收到 Allow-Methods") + assert.Empty(t, w.Header().Get("Access-Control-Max-Age"), + "不允许的 origin 不应收到 Max-Age") + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"), + "不允许的 origin 不应收到 Allow-Origin") + }) + } +} + +func TestCORS_AllowedOrigin_HasAllowHeaders(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://allowed.example.com"}, + AllowCredentials: false, + } + middleware := CORS(cfg) + + tests := []struct { + name string + method string + }{ + {name: "preflight_OPTIONS", method: http.MethodOptions}, + {name: "normal_GET", method: http.MethodGet}, + {name: "normal_POST", method: http.MethodPost}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(tt.method, "/", nil) + c.Request.Header.Set("Origin", "https://allowed.example.com") + + middleware(c) + + // 应设置 Allow-Headers、Allow-Methods 和 Max-Age + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"), + "允许的 origin 应收到 Allow-Headers") + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"), + "允许的 origin 应收到 Allow-Methods") + assert.Equal(t, "86400", 
w.Header().Get("Access-Control-Max-Age"), + "允许的 origin 应收到 Max-Age=86400") + assert.Equal(t, "https://allowed.example.com", w.Header().Get("Access-Control-Allow-Origin"), + "允许的 origin 应收到 Allow-Origin") + }) + } +} + +func TestCORS_PreflightDisallowedOrigin_ReturnsForbidden(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://allowed.example.com"}, + AllowCredentials: false, + } + middleware := CORS(cfg) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodOptions, "/", nil) + c.Request.Header.Set("Origin", "https://evil.example.com") + + middleware(c) + + assert.Equal(t, http.StatusForbidden, w.Code, + "不允许的 origin 的 preflight 请求应返回 403") +} + +func TestCORS_PreflightAllowedOrigin_ReturnsNoContent(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://allowed.example.com"}, + AllowCredentials: false, + } + middleware := CORS(cfg) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodOptions, "/", nil) + c.Request.Header.Set("Origin", "https://allowed.example.com") + + middleware(c) + + assert.Equal(t, http.StatusNoContent, w.Code, + "允许的 origin 的 preflight 请求应返回 204") +} + +func TestCORS_WildcardOrigin_AllowsAny(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"*"}, + AllowCredentials: false, + } + middleware := CORS(cfg) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://any-origin.example.com") + + middleware(c) + + assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"), + "通配符配置应返回 *") + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"), + "通配符 origin 应设置 Allow-Headers") + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"), + "通配符 origin 应设置 Allow-Methods") +} + +func TestCORS_AllowCredentials_SetCorrectly(t 
*testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://allowed.example.com"}, + AllowCredentials: true, + } + middleware := CORS(cfg) + + t.Run("allowed_origin_gets_credentials", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://allowed.example.com") + + middleware(c) + + assert.Equal(t, "true", w.Header().Get("Access-Control-Allow-Credentials"), + "允许的 origin 且开启 credentials 应设置 Allow-Credentials") + }) + + t.Run("disallowed_origin_no_credentials", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://evil.example.com") + + middleware(c) + + assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"), + "不允许的 origin 不应收到 Allow-Credentials") + }) +} + +func TestCORS_WildcardWithCredentials_DisablesCredentials(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"*"}, + AllowCredentials: true, + } + middleware := CORS(cfg) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://any.example.com") + + middleware(c) + + // 通配符 + credentials 不兼容,credentials 应被禁用 + assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"), + "通配符 origin 应禁用 Allow-Credentials") +} + +func TestCORS_MultipleAllowedOrigins(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{ + "https://app1.example.com", + "https://app2.example.com", + }, + AllowCredentials: false, + } + middleware := CORS(cfg) + + t.Run("first_origin_allowed", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://app1.example.com") + + 
middleware(c) + + assert.Equal(t, "https://app1.example.com", w.Header().Get("Access-Control-Allow-Origin")) + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers")) + }) + + t.Run("second_origin_allowed", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://app2.example.com") + + middleware(c) + + assert.Equal(t, "https://app2.example.com", w.Header().Get("Access-Control-Allow-Origin")) + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers")) + }) + + t.Run("unlisted_origin_rejected", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://app3.example.com") + + middleware(c) + + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) + assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers")) + }) +} + +func TestCORS_VaryHeader_SetForSpecificOrigin(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://allowed.example.com"}, + AllowCredentials: false, + } + middleware := CORS(cfg) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://allowed.example.com") + + middleware(c) + + assert.Contains(t, w.Header().Values("Vary"), "Origin", + "非通配符允许的 origin 应设置 Vary: Origin") +} + +func TestNormalizeOrigins(t *testing.T) { + tests := []struct { + name string + input []string + expect []string + }{ + {name: "nil_input", input: nil, expect: nil}, + {name: "empty_input", input: []string{}, expect: nil}, + {name: "trims_whitespace", input: []string{" https://a.com ", " https://b.com"}, expect: []string{"https://a.com", "https://b.com"}}, + {name: "removes_empty_strings", input: []string{"", " ", "https://a.com"}, expect: 
[]string{"https://a.com"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := normalizeOrigins(tt.input) + assert.Equal(t, tt.expect, result) + }) + } +} diff --git a/backend/internal/server/middleware/jwt_auth.go b/backend/internal/server/middleware/jwt_auth.go index 9a89aab7..4aceb355 100644 --- a/backend/internal/server/middleware/jwt_auth.go +++ b/backend/internal/server/middleware/jwt_auth.go @@ -26,12 +26,12 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService) // 验证Bearer scheme parts := strings.SplitN(authHeader, " ", 2) - if len(parts) != 2 || parts[0] != "Bearer" { + if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") { AbortWithError(c, 401, "INVALID_AUTH_HEADER", "Authorization header format must be 'Bearer {token}'") return } - tokenString := parts[1] + tokenString := strings.TrimSpace(parts[1]) if tokenString == "" { AbortWithError(c, 401, "EMPTY_TOKEN", "Token cannot be empty") return diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go new file mode 100644 index 00000000..f8839cfe --- /dev/null +++ b/backend/internal/server/middleware/jwt_auth_test.go @@ -0,0 +1,256 @@ +//go:build unit + +package middleware + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// stubJWTUserRepo 实现 UserRepository 的最小子集,仅支持 GetByID。 +type stubJWTUserRepo struct { + service.UserRepository + users map[int64]*service.User +} + +func (r *stubJWTUserRepo) GetByID(_ context.Context, id int64) (*service.User, error) { + u, ok := r.users[id] + if !ok { + return nil, errors.New("user not found") + } + return u, nil +} + +// newJWTTestEnv 创建 JWT 认证中间件测试环境。 +// 返回 gin.Engine(已注册 JWT 中间件)和 AuthService(用于生成 Token)。 +func 
newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthService) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.JWT.Secret = "test-jwt-secret-32bytes-long!!!" + cfg.JWT.AccessTokenExpireMinutes = 60 + + userRepo := &stubJWTUserRepo{users: users} + authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + userSvc := service.NewUserService(userRepo, nil, nil) + mw := NewJWTAuthMiddleware(authSvc, userSvc) + + r := gin.New() + r.Use(gin.HandlerFunc(mw)) + r.GET("/protected", func(c *gin.Context) { + subject, _ := GetAuthSubjectFromContext(c) + role, _ := GetUserRoleFromContext(c) + c.JSON(http.StatusOK, gin.H{ + "user_id": subject.UserID, + "role": role, + }) + }) + return r, authSvc +} + +func TestJWTAuth_ValidToken(t *testing.T) { + user := &service.User{ + ID: 1, + Email: "test@example.com", + Role: "user", + Status: service.StatusActive, + Concurrency: 5, + TokenVersion: 1, + } + router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user}) + + token, err := authSvc.GenerateToken(user) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + + var body map[string]any + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, float64(1), body["user_id"]) + require.Equal(t, "user", body["role"]) +} + +func TestJWTAuth_ValidToken_LowercaseBearer(t *testing.T) { + user := &service.User{ + ID: 1, + Email: "test@example.com", + Role: "user", + Status: service.StatusActive, + Concurrency: 5, + TokenVersion: 1, + } + router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user}) + + token, err := authSvc.GenerateToken(user) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "bearer "+token) 
+ router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) +} + +func TestJWTAuth_MissingAuthorizationHeader(t *testing.T) { + router, _ := newJWTTestEnv(nil) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "UNAUTHORIZED", body.Code) +} + +func TestJWTAuth_InvalidHeaderFormat(t *testing.T) { + tests := []struct { + name string + header string + }{ + {"无Bearer前缀", "Token abc123"}, + {"缺少空格分隔", "Bearerabc123"}, + {"仅有单词", "abc123"}, + } + router, _ := newJWTTestEnv(nil) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", tt.header) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "INVALID_AUTH_HEADER", body.Code) + }) + } +} + +func TestJWTAuth_EmptyToken(t *testing.T) { + router, _ := newJWTTestEnv(nil) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer ") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "EMPTY_TOKEN", body.Code) +} + +func TestJWTAuth_TamperedToken(t *testing.T) { + router, _ := newJWTTestEnv(nil) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer eyJhbGciOiJIUzI1NiJ9.eyJ1c2VyX2lkIjoxfQ.invalid_signature") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + 
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "INVALID_TOKEN", body.Code) +} + +func TestJWTAuth_UserNotFound(t *testing.T) { + // 使用 user ID=1 的 token,但 repo 中没有该用户 + fakeUser := &service.User{ + ID: 999, + Email: "ghost@example.com", + Role: "user", + Status: service.StatusActive, + TokenVersion: 1, + } + // 创建环境时不注入此用户,这样 GetByID 会失败 + router, authSvc := newJWTTestEnv(map[int64]*service.User{}) + + token, err := authSvc.GenerateToken(fakeUser) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "USER_NOT_FOUND", body.Code) +} + +func TestJWTAuth_UserInactive(t *testing.T) { + user := &service.User{ + ID: 1, + Email: "disabled@example.com", + Role: "user", + Status: service.StatusDisabled, + TokenVersion: 1, + } + router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user}) + + token, err := authSvc.GenerateToken(user) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "USER_INACTIVE", body.Code) +} + +func TestJWTAuth_TokenVersionMismatch(t *testing.T) { + // Token 生成时 TokenVersion=1,但数据库中用户已更新为 TokenVersion=2(密码修改) + userForToken := &service.User{ + ID: 1, + Email: "test@example.com", + Role: "user", + Status: service.StatusActive, + TokenVersion: 1, + } + userInDB := &service.User{ + ID: 1, + Email: "test@example.com", + Role: "user", + Status: service.StatusActive, + TokenVersion: 2, // 密码修改后版本递增 + } + router, authSvc := 
newJWTTestEnv(map[int64]*service.User{1: userInDB}) + + token, err := authSvc.GenerateToken(userForToken) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "TOKEN_REVOKED", body.Code) +} diff --git a/backend/internal/server/middleware/logger.go b/backend/internal/server/middleware/logger.go index 842efda9..b14a3a21 100644 --- a/backend/internal/server/middleware/logger.go +++ b/backend/internal/server/middleware/logger.go @@ -1,10 +1,12 @@ package middleware import ( - "log" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/gin-gonic/gin" + "go.uber.org/zap" ) // Logger 请求日志中间件 @@ -13,44 +15,52 @@ func Logger() gin.HandlerFunc { // 开始时间 startTime := time.Now() - // 处理请求 - c.Next() - - // 结束时间 - endTime := time.Now() - - // 执行时间 - latency := endTime.Sub(startTime) - - // 请求方法 - method := c.Request.Method - // 请求路径 path := c.Request.URL.Path - // 状态码 + // 处理请求 + c.Next() + + // 跳过健康检查等高频探针路径的日志 + if path == "/health" || path == "/setup/status" { + return + } + + endTime := time.Now() + latency := endTime.Sub(startTime) + + method := c.Request.Method statusCode := c.Writer.Status() - - // 客户端IP clientIP := c.ClientIP() - - // 协议版本 protocol := c.Request.Proto + accountID, hasAccountID := c.Request.Context().Value(ctxkey.AccountID).(int64) + platform, _ := c.Request.Context().Value(ctxkey.Platform).(string) + model, _ := c.Request.Context().Value(ctxkey.Model).(string) - // 日志格式: [时间] 状态码 | 延迟 | IP | 协议 | 方法 路径 - log.Printf("[GIN] %v | %3d | %13v | %15s | %-6s | %-7s %s", - endTime.Format("2006/01/02 - 15:04:05"), - statusCode, - latency, - clientIP, - protocol, - method, - path, - ) + fields := 
[]zap.Field{ + zap.String("component", "http.access"), + zap.Int("status_code", statusCode), + zap.Int64("latency_ms", latency.Milliseconds()), + zap.String("client_ip", clientIP), + zap.String("protocol", protocol), + zap.String("method", method), + zap.String("path", path), + } + if hasAccountID && accountID > 0 { + fields = append(fields, zap.Int64("account_id", accountID)) + } + if platform != "" { + fields = append(fields, zap.String("platform", platform)) + } + if model != "" { + fields = append(fields, zap.String("model", model)) + } + + l := logger.FromContext(c.Request.Context()).With(fields...) + l.Info("http request completed", zap.Time("completed_at", endTime)) - // 如果有错误,额外记录错误信息 if len(c.Errors) > 0 { - log.Printf("[GIN] Errors: %v", c.Errors.String()) + l.Warn("http request contains gin errors", zap.String("errors", c.Errors.String())) } } } diff --git a/backend/internal/server/middleware/middleware.go b/backend/internal/server/middleware/middleware.go index 26572019..27985cf8 100644 --- a/backend/internal/server/middleware/middleware.go +++ b/backend/internal/server/middleware/middleware.go @@ -2,8 +2,11 @@ package middleware import ( "context" + "net/http" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" ) @@ -71,3 +74,48 @@ func AbortWithError(c *gin.Context, statusCode int, code, message string) { c.JSON(statusCode, NewErrorResponse(code, message)) c.Abort() } + +// ────────────────────────────────────────────────────────── +// RequireGroupAssignment — 未分组 Key 拦截中间件 +// ────────────────────────────────────────────────────────── + +// GatewayErrorWriter 定义网关错误响应格式(不同协议使用不同格式) +type GatewayErrorWriter func(c *gin.Context, status int, message string) + +// AnthropicErrorWriter 按 Anthropic API 规范输出错误 +func AnthropicErrorWriter(c *gin.Context, status int, message string) { + c.JSON(status, gin.H{ + "type": "error", + 
"error": gin.H{"type": "permission_error", "message": message}, + }) +} + +// GoogleErrorWriter 按 Google API 规范输出错误 +func GoogleErrorWriter(c *gin.Context, status int, message string) { + c.JSON(status, gin.H{ + "error": gin.H{ + "code": status, + "message": message, + "status": googleapi.HTTPStatusToGoogleStatus(status), + }, + }) +} + +// RequireGroupAssignment 检查 API Key 是否已分配到分组, +// 如果未分组且系统设置不允许未分组 Key 调度则返回 403。 +func RequireGroupAssignment(settingService *service.SettingService, writeError GatewayErrorWriter) gin.HandlerFunc { + return func(c *gin.Context) { + apiKey, ok := GetAPIKeyFromContext(c) + if !ok || apiKey.GroupID != nil { + c.Next() + return + } + // 未分组 Key — 检查系统设置 + if settingService.IsUngroupedKeySchedulingAllowed(c.Request.Context()) { + c.Next() + return + } + writeError(c, http.StatusForbidden, "API Key is not assigned to any group and cannot be used. Please contact the administrator to assign it to a group.") + c.Abort() + } +} diff --git a/backend/internal/server/middleware/misc_coverage_test.go b/backend/internal/server/middleware/misc_coverage_test.go new file mode 100644 index 00000000..c0adfc4d --- /dev/null +++ b/backend/internal/server/middleware/misc_coverage_test.go @@ -0,0 +1,126 @@ +//go:build unit + +package middleware + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestClientRequestID_GeneratesWhenMissing(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(ClientRequestID()) + r.GET("/t", func(c *gin.Context) { + v := c.Request.Context().Value(ctxkey.ClientRequestID) + require.NotNil(t, v) + id, ok := v.(string) + require.True(t, ok) + require.NotEmpty(t, id) + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + 
r.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) +} + +func TestClientRequestID_PreservesExisting(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(ClientRequestID()) + r.GET("/t", func(c *gin.Context) { + id, ok := c.Request.Context().Value(ctxkey.ClientRequestID).(string) + require.True(t, ok) + require.Equal(t, "keep", id) + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req = req.WithContext(context.WithValue(req.Context(), ctxkey.ClientRequestID, "keep")) + r.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) +} + +func TestRequestBodyLimit_LimitsBody(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(RequestBodyLimit(4)) + r.POST("/t", func(c *gin.Context) { + _, err := io.ReadAll(c.Request.Body) + require.Error(t, err) + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/t", bytes.NewBufferString("12345")) + r.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) +} + +func TestForcePlatform_SetsContextAndGinValue(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(ForcePlatform("anthropic")) + r.GET("/t", func(c *gin.Context) { + require.True(t, HasForcePlatform(c)) + v, ok := GetForcePlatformFromContext(c) + require.True(t, ok) + require.Equal(t, "anthropic", v) + + ctxV := c.Request.Context().Value(ctxkey.ForcePlatform) + require.Equal(t, "anthropic", ctxV) + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + r.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) +} + +func TestAuthSubjectHelpers_RoundTrip(t *testing.T) { + c := &gin.Context{} + c.Set(string(ContextKeyUser), AuthSubject{UserID: 1, Concurrency: 2}) + c.Set(string(ContextKeyUserRole), "admin") + + sub, ok := GetAuthSubjectFromContext(c) + require.True(t, ok) + require.Equal(t, int64(1), sub.UserID) + 
require.Equal(t, 2, sub.Concurrency) + + role, ok := GetUserRoleFromContext(c) + require.True(t, ok) + require.Equal(t, "admin", role) +} + +func TestAPIKeyAndSubscriptionFromContext(t *testing.T) { + c := &gin.Context{} + + key := &service.APIKey{ID: 1} + c.Set(string(ContextKeyAPIKey), key) + gotKey, ok := GetAPIKeyFromContext(c) + require.True(t, ok) + require.Equal(t, int64(1), gotKey.ID) + + sub := &service.UserSubscription{ID: 2} + c.Set(string(ContextKeySubscription), sub) + gotSub, ok := GetSubscriptionFromContext(c) + require.True(t, ok) + require.Equal(t, int64(2), gotSub.ID) +} diff --git a/backend/internal/server/middleware/recovery_test.go b/backend/internal/server/middleware/recovery_test.go index 439f44cb..33e71d51 100644 --- a/backend/internal/server/middleware/recovery_test.go +++ b/backend/internal/server/middleware/recovery_test.go @@ -3,6 +3,7 @@ package middleware import ( + "bytes" "encoding/json" "net/http" "net/http/httptest" @@ -14,6 +15,34 @@ import ( "github.com/stretchr/testify/require" ) +func TestRecovery_PanicLogContainsInfo(t *testing.T) { + gin.SetMode(gin.TestMode) + + // 临时替换 DefaultErrorWriter 以捕获日志输出 + var buf bytes.Buffer + originalWriter := gin.DefaultErrorWriter + gin.DefaultErrorWriter = &buf + t.Cleanup(func() { + gin.DefaultErrorWriter = originalWriter + }) + + r := gin.New() + r.Use(Recovery()) + r.GET("/panic", func(c *gin.Context) { + panic("custom panic message for test") + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/panic", nil) + r.ServeHTTP(w, req) + + require.Equal(t, http.StatusInternalServerError, w.Code) + + logOutput := buf.String() + require.Contains(t, logOutput, "custom panic message for test", "日志应包含 panic 信息") + require.Contains(t, logOutput, "recovery_test.go", "日志应包含堆栈跟踪文件名") +} + func TestRecovery(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/server/middleware/request_access_logger_test.go 
b/backend/internal/server/middleware/request_access_logger_test.go new file mode 100644 index 00000000..fec3ed22 --- /dev/null +++ b/backend/internal/server/middleware/request_access_logger_test.go @@ -0,0 +1,228 @@ +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/gin-gonic/gin" +) + +type testLogSink struct { + mu sync.Mutex + events []*logger.LogEvent +} + +func (s *testLogSink) WriteLogEvent(event *logger.LogEvent) { + s.mu.Lock() + defer s.mu.Unlock() + s.events = append(s.events, event) +} + +func (s *testLogSink) list() []*logger.LogEvent { + s.mu.Lock() + defer s.mu.Unlock() + out := make([]*logger.LogEvent, len(s.events)) + copy(out, s.events) + return out +} + +func initMiddlewareTestLogger(t *testing.T) *testLogSink { + return initMiddlewareTestLoggerWithLevel(t, "debug") +} + +func initMiddlewareTestLoggerWithLevel(t *testing.T, level string) *testLogSink { + t.Helper() + level = strings.TrimSpace(level) + if level == "" { + level = "debug" + } + if err := logger.Init(logger.InitOptions{ + Level: level, + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: false, + ToFile: false, + }, + }); err != nil { + t.Fatalf("init logger: %v", err) + } + sink := &testLogSink{} + logger.SetSink(sink) + t.Cleanup(func() { + logger.SetSink(nil) + }) + return sink +} + +func TestRequestLogger_GenerateAndPropagateRequestID(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(RequestLogger()) + r.GET("/t", func(c *gin.Context) { + reqID, ok := c.Request.Context().Value(ctxkey.RequestID).(string) + if !ok || reqID == "" { + t.Fatalf("request_id missing in context") + } + if got := c.Writer.Header().Get(requestIDHeader); got != reqID { + t.Fatalf("response header request_id mismatch, header=%q ctx=%q", got, reqID) + } + 
c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("status=%d", w.Code) + } + if w.Header().Get(requestIDHeader) == "" { + t.Fatalf("X-Request-ID should be set") + } +} + +func TestRequestLogger_KeepIncomingRequestID(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(RequestLogger()) + r.GET("/t", func(c *gin.Context) { + reqID, _ := c.Request.Context().Value(ctxkey.RequestID).(string) + if reqID != "rid-fixed" { + t.Fatalf("request_id=%q, want rid-fixed", reqID) + } + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set(requestIDHeader, "rid-fixed") + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("status=%d", w.Code) + } + if got := w.Header().Get(requestIDHeader); got != "rid-fixed" { + t.Fatalf("header=%q, want rid-fixed", got) + } +} + +func TestLogger_AccessLogIncludesCoreFields(t *testing.T) { + gin.SetMode(gin.TestMode) + sink := initMiddlewareTestLogger(t) + + r := gin.New() + r.Use(Logger()) + r.Use(func(c *gin.Context) { + ctx := c.Request.Context() + ctx = context.WithValue(ctx, ctxkey.AccountID, int64(101)) + ctx = context.WithValue(ctx, ctxkey.Platform, "openai") + ctx = context.WithValue(ctx, ctxkey.Model, "gpt-5") + c.Request = c.Request.WithContext(ctx) + c.Next() + }) + r.GET("/api/test", func(c *gin.Context) { + c.Status(http.StatusCreated) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/test", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusCreated { + t.Fatalf("status=%d", w.Code) + } + + events := sink.list() + if len(events) == 0 { + t.Fatalf("expected at least one log event") + } + found := false + for _, event := range events { + if event == nil || event.Message != "http request completed" { + continue + } + found = true + switch v := 
event.Fields["status_code"].(type) { + case int: + if v != http.StatusCreated { + t.Fatalf("status_code field mismatch: %v", v) + } + case int64: + if v != int64(http.StatusCreated) { + t.Fatalf("status_code field mismatch: %v", v) + } + default: + t.Fatalf("status_code type mismatch: %T", v) + } + switch v := event.Fields["account_id"].(type) { + case int64: + if v != 101 { + t.Fatalf("account_id field mismatch: %v", v) + } + case int: + if v != 101 { + t.Fatalf("account_id field mismatch: %v", v) + } + default: + t.Fatalf("account_id type mismatch: %T", v) + } + if event.Fields["platform"] != "openai" || event.Fields["model"] != "gpt-5" { + t.Fatalf("platform/model mismatch: %+v", event.Fields) + } + } + if !found { + t.Fatalf("access log event not found") + } +} + +func TestLogger_HealthPathSkipped(t *testing.T) { + gin.SetMode(gin.TestMode) + sink := initMiddlewareTestLogger(t) + + r := gin.New() + r.Use(Logger()) + r.GET("/health", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/health", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("status=%d", w.Code) + } + if len(sink.list()) != 0 { + t.Fatalf("health endpoint should not write access log") + } +} + +func TestLogger_AccessLogDroppedWhenLevelWarn(t *testing.T) { + gin.SetMode(gin.TestMode) + sink := initMiddlewareTestLoggerWithLevel(t, "warn") + + r := gin.New() + r.Use(RequestLogger()) + r.Use(Logger()) + r.GET("/api/test", func(c *gin.Context) { + c.Status(http.StatusCreated) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/test", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusCreated { + t.Fatalf("status=%d", w.Code) + } + + events := sink.list() + for _, event := range events { + if event != nil && event.Message == "http request completed" { + t.Fatalf("access log should not be indexed when level=warn: %+v", event) + } + } +} diff --git 
a/backend/internal/server/middleware/request_logger.go b/backend/internal/server/middleware/request_logger.go new file mode 100644 index 00000000..0fb2feca --- /dev/null +++ b/backend/internal/server/middleware/request_logger.go @@ -0,0 +1,45 @@ +package middleware + +import ( + "context" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "go.uber.org/zap" +) + +const requestIDHeader = "X-Request-ID" + +// RequestLogger 在请求入口注入 request-scoped logger。 +func RequestLogger() gin.HandlerFunc { + return func(c *gin.Context) { + if c.Request == nil { + c.Next() + return + } + + requestID := strings.TrimSpace(c.GetHeader(requestIDHeader)) + if requestID == "" { + requestID = uuid.NewString() + } + c.Header(requestIDHeader, requestID) + + ctx := context.WithValue(c.Request.Context(), ctxkey.RequestID, requestID) + clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string) + + requestLogger := logger.With( + zap.String("component", "http"), + zap.String("request_id", requestID), + zap.String("client_request_id", strings.TrimSpace(clientRequestID)), + zap.String("path", c.Request.URL.Path), + zap.String("method", c.Request.Method), + ) + + ctx = logger.IntoContext(ctx, requestLogger) + c.Request = c.Request.WithContext(ctx) + c.Next() + } +} diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go index 9ce7f449..d9ec951e 100644 --- a/backend/internal/server/middleware/security_headers.go +++ b/backend/internal/server/middleware/security_headers.go @@ -3,6 +3,8 @@ package middleware import ( "crypto/rand" "encoding/base64" + "fmt" + "log" "strings" "github.com/Wei-Shaw/sub2api/internal/config" @@ -18,11 +20,14 @@ const ( CloudflareInsightsDomain = "https://static.cloudflareinsights.com" ) -// GenerateNonce generates a cryptographically secure random nonce -func GenerateNonce() 
string { +// GenerateNonce generates a cryptographically secure random nonce. +// 返回 error 以确保调用方在 crypto/rand 失败时能正确降级。 +func GenerateNonce() (string, error) { b := make([]byte, 16) - _, _ = rand.Read(b) - return base64.StdEncoding.EncodeToString(b) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("generate CSP nonce: %w", err) + } + return base64.StdEncoding.EncodeToString(b), nil } // GetNonceFromContext retrieves the CSP nonce from gin context @@ -36,7 +41,9 @@ func GetNonceFromContext(c *gin.Context) string { } // SecurityHeaders sets baseline security headers for all responses. -func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc { +// getFrameSrcOrigins is an optional function that returns extra origins to inject into frame-src; +// pass nil to disable dynamic frame-src injection. +func SecurityHeaders(cfg config.CSPConfig, getFrameSrcOrigins func() []string) gin.HandlerFunc { policy := strings.TrimSpace(cfg.Policy) if policy == "" { policy = config.DefaultCSPPolicy @@ -46,23 +53,51 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc { policy = enhanceCSPPolicy(policy) return func(c *gin.Context) { + finalPolicy := policy + if getFrameSrcOrigins != nil { + for _, origin := range getFrameSrcOrigins() { + if origin != "" { + finalPolicy = addToDirective(finalPolicy, "frame-src", origin) + } + } + } + c.Header("X-Content-Type-Options", "nosniff") c.Header("X-Frame-Options", "DENY") c.Header("Referrer-Policy", "strict-origin-when-cross-origin") + if isAPIRoutePath(c) { + c.Next() + return + } if cfg.Enabled { // Generate nonce for this request - nonce := GenerateNonce() - c.Set(CSPNonceKey, nonce) - - // Replace nonce placeholder in policy - finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'") - c.Header("Content-Security-Policy", finalPolicy) + nonce, err := GenerateNonce() + if err != nil { + // crypto/rand 失败时降级为无 nonce 的 CSP 策略 + log.Printf("[SecurityHeaders] %v — 降级为无 nonce 的 CSP", err) + 
c.Header("Content-Security-Policy", strings.ReplaceAll(finalPolicy, NonceTemplate, "'unsafe-inline'")) + } else { + c.Set(CSPNonceKey, nonce) + c.Header("Content-Security-Policy", strings.ReplaceAll(finalPolicy, NonceTemplate, "'nonce-"+nonce+"'")) + } } c.Next() } } +func isAPIRoutePath(c *gin.Context) bool { + if c == nil || c.Request == nil || c.Request.URL == nil { + return false + } + path := c.Request.URL.Path + return strings.HasPrefix(path, "/v1/") || + strings.HasPrefix(path, "/v1beta/") || + strings.HasPrefix(path, "/antigravity/") || + strings.HasPrefix(path, "/sora/") || + strings.HasPrefix(path, "/responses") +} + // enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain. // This allows the application to work correctly even if the config file has an older CSP policy. func enhanceCSPPolicy(policy string) string { diff --git a/backend/internal/server/middleware/security_headers_test.go b/backend/internal/server/middleware/security_headers_test.go index dc7a87d8..031385d0 100644 --- a/backend/internal/server/middleware/security_headers_test.go +++ b/backend/internal/server/middleware/security_headers_test.go @@ -19,7 +19,8 @@ func init() { func TestGenerateNonce(t *testing.T) { t.Run("generates_valid_base64_string", func(t *testing.T) { - nonce := GenerateNonce() + nonce, err := GenerateNonce() + require.NoError(t, err) // Should be valid base64 decoded, err := base64.StdEncoding.DecodeString(nonce) @@ -32,14 +33,16 @@ func TestGenerateNonce(t *testing.T) { t.Run("generates_unique_nonces", func(t *testing.T) { nonces := make(map[string]bool) for i := 0; i < 100; i++ { - nonce := GenerateNonce() + nonce, err := GenerateNonce() + require.NoError(t, err) assert.False(t, nonces[nonce], "nonce should be unique") nonces[nonce] = true } }) t.Run("nonce_has_expected_length", func(t *testing.T) { - nonce := GenerateNonce() + nonce, err := GenerateNonce() + require.NoError(t, err) // 16 bytes -> 24 chars in base64 (with 
padding) assert.Len(t, nonce, 24) }) @@ -81,7 +84,7 @@ func TestGetNonceFromContext(t *testing.T) { func TestSecurityHeaders(t *testing.T) { t.Run("sets_basic_security_headers", func(t *testing.T) { cfg := config.CSPConfig{Enabled: false} - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -96,7 +99,7 @@ func TestSecurityHeaders(t *testing.T) { t.Run("csp_disabled_no_csp_header", func(t *testing.T) { cfg := config.CSPConfig{Enabled: false} - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -112,7 +115,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "default-src 'self'", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -128,12 +131,32 @@ func TestSecurityHeaders(t *testing.T) { assert.Contains(t, csp, CloudflareInsightsDomain) }) + t.Run("api_route_skips_csp_nonce_generation", func(t *testing.T) { + cfg := config.CSPConfig{ + Enabled: true, + Policy: "default-src 'self'; script-src 'self' __CSP_NONCE__", + } + middleware := SecurityHeaders(cfg, nil) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + middleware(c) + + assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options")) + assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options")) + assert.Equal(t, "strict-origin-when-cross-origin", w.Header().Get("Referrer-Policy")) + assert.Empty(t, w.Header().Get("Content-Security-Policy")) + assert.Empty(t, GetNonceFromContext(c)) + }) + t.Run("csp_enabled_with_nonce_placeholder", func(t *testing.T) { cfg := config.CSPConfig{ Enabled: true, Policy: "script-src 'self' __CSP_NONCE__", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := 
httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -157,7 +180,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -176,7 +199,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: " \t\n ", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -194,7 +217,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "script-src __CSP_NONCE__; style-src __CSP_NONCE__", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -212,7 +235,7 @@ func TestSecurityHeaders(t *testing.T) { t.Run("calls_next_handler", func(t *testing.T) { cfg := config.CSPConfig{Enabled: true, Policy: "default-src 'self'"} - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) nextCalled := false router := gin.New() @@ -235,7 +258,7 @@ func TestSecurityHeaders(t *testing.T) { Enabled: true, Policy: "script-src __CSP_NONCE__", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) nonces := make(map[string]bool) for i := 0; i < 10; i++ { @@ -344,7 +367,7 @@ func TestAddToDirective(t *testing.T) { // Benchmark tests func BenchmarkGenerateNonce(b *testing.B) { for i := 0; i < b.N; i++ { - GenerateNonce() + _, _ = GenerateNonce() } } @@ -353,7 +376,7 @@ func BenchmarkSecurityHeadersMiddleware(b *testing.B) { Enabled: true, Policy: "script-src 'self' __CSP_NONCE__", } - middleware := SecurityHeaders(cfg) + middleware := SecurityHeaders(cfg, nil) b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index cf9015e4..571986b4 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go 
@@ -1,7 +1,10 @@ package server import ( + "context" "log" + "sync/atomic" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" @@ -14,6 +17,8 @@ import ( "github.com/redis/go-redis/v9" ) +const frameSrcRefreshTimeout = 5 * time.Second + // SetupRouter 配置路由器中间件和路由 func SetupRouter( r *gin.Engine, @@ -28,10 +33,33 @@ func SetupRouter( cfg *config.Config, redisClient *redis.Client, ) *gin.Engine { + // 缓存 iframe 页面的 origin 列表,用于动态注入 CSP frame-src + var cachedFrameOrigins atomic.Pointer[[]string] + emptyOrigins := []string{} + cachedFrameOrigins.Store(&emptyOrigins) + + refreshFrameOrigins := func() { + ctx, cancel := context.WithTimeout(context.Background(), frameSrcRefreshTimeout) + defer cancel() + origins, err := settingService.GetFrameSrcOrigins(ctx) + if err != nil { + // 获取失败时保留已有缓存,避免 frame-src 被意外清空 + return + } + cachedFrameOrigins.Store(&origins) + } + refreshFrameOrigins() // 启动时初始化 + // 应用中间件 + r.Use(middleware2.RequestLogger()) r.Use(middleware2.Logger()) r.Use(middleware2.CORS(cfg.CORS)) - r.Use(middleware2.SecurityHeaders(cfg.Security.CSP)) + r.Use(middleware2.SecurityHeaders(cfg.Security.CSP, func() []string { + if p := cachedFrameOrigins.Load(); p != nil { + return *p + } + return nil + })) // Serve embedded frontend with settings injection if available if web.HasEmbeddedFrontend() { @@ -39,15 +67,21 @@ func SetupRouter( if err != nil { log.Printf("Warning: Failed to create frontend server with settings injection: %v, using legacy mode", err) r.Use(web.ServeEmbeddedFrontend()) + settingService.SetOnUpdateCallback(refreshFrameOrigins) } else { - // Register cache invalidation callback - settingService.SetOnUpdateCallback(frontendServer.InvalidateCache) + // Register combined callback: invalidate HTML cache + refresh frame origins + settingService.SetOnUpdateCallback(func() { + frontendServer.InvalidateCache() + refreshFrameOrigins() + }) r.Use(frontendServer.Middleware()) } + } else { + 
settingService.SetOnUpdateCallback(refreshFrameOrigins) } // 注册路由 - registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg, redisClient) + registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient) return r } @@ -62,6 +96,7 @@ func registerRoutes( apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, opsService *service.OpsService, + settingService *service.SettingService, cfg *config.Config, redisClient *redis.Client, ) { @@ -74,6 +109,7 @@ func registerRoutes( // 注册各模块路由 routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient) routes.RegisterUserRoutes(v1, h, jwtAuth) + routes.RegisterSoraClientRoutes(v1, h, jwtAuth) routes.RegisterAdminRoutes(v1, h, adminAuth) - routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg) + routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg) } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 4509b4bc..2b6077c1 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -34,6 +34,8 @@ func RegisterAdminRoutes( // OpenAI OAuth registerOpenAIOAuthRoutes(admin, h) + // Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立) + registerSoraOAuthRoutes(admin, h) // Gemini OAuth registerGeminiOAuthRoutes(admin, h) @@ -53,6 +55,9 @@ func RegisterAdminRoutes( // 系统设置 registerSettingsRoutes(admin, h) + // 数据管理 + registerDataManagementRoutes(admin, h) + // 运维监控(Ops) registerOpsRoutes(admin, h) @@ -70,6 +75,16 @@ func RegisterAdminRoutes( // 错误透传规则管理 registerErrorPassthroughRoutes(admin, h) + + // API Key 管理 + registerAdminAPIKeyRoutes(admin, h) + } +} + +func registerAdminAPIKeyRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + apiKeys := admin.Group("/api-keys") + { + apiKeys.PUT("/:id", 
h.Admin.APIKey.UpdateGroup) } } @@ -101,6 +116,9 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { { runtime.GET("/alert", h.Admin.Ops.GetAlertRuntimeSettings) runtime.PUT("/alert", h.Admin.Ops.UpdateAlertRuntimeSettings) + runtime.GET("/logging", h.Admin.Ops.GetRuntimeLogConfig) + runtime.PUT("/logging", h.Admin.Ops.UpdateRuntimeLogConfig) + runtime.POST("/logging/reset", h.Admin.Ops.ResetRuntimeLogConfig) } // Advanced settings (DB-backed) @@ -144,22 +162,31 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { // Request drilldown (success + error) ops.GET("/requests", h.Admin.Ops.ListRequestDetails) + // Indexed system logs + ops.GET("/system-logs", h.Admin.Ops.ListSystemLogs) + ops.POST("/system-logs/cleanup", h.Admin.Ops.CleanupSystemLogs) + ops.GET("/system-logs/health", h.Admin.Ops.GetSystemLogIngestionHealth) + // Dashboard (vNext - raw path for MVP) + ops.GET("/dashboard/snapshot-v2", h.Admin.Ops.GetDashboardSnapshotV2) ops.GET("/dashboard/overview", h.Admin.Ops.GetDashboardOverview) ops.GET("/dashboard/throughput-trend", h.Admin.Ops.GetDashboardThroughputTrend) ops.GET("/dashboard/latency-histogram", h.Admin.Ops.GetDashboardLatencyHistogram) ops.GET("/dashboard/error-trend", h.Admin.Ops.GetDashboardErrorTrend) ops.GET("/dashboard/error-distribution", h.Admin.Ops.GetDashboardErrorDistribution) + ops.GET("/dashboard/openai-token-stats", h.Admin.Ops.GetDashboardOpenAITokenStats) } } func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) { dashboard := admin.Group("/dashboard") { + dashboard.GET("/snapshot-v2", h.Admin.Dashboard.GetSnapshotV2) dashboard.GET("/stats", h.Admin.Dashboard.GetStats) dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics) dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend) dashboard.GET("/models", h.Admin.Dashboard.GetModelStats) + dashboard.GET("/groups", h.Admin.Dashboard.GetGroupStats) dashboard.GET("/api-keys-trend", 
h.Admin.Dashboard.GetAPIKeyUsageTrend) dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend) dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage) @@ -208,6 +235,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.GET("", h.Admin.Account.List) accounts.GET("/:id", h.Admin.Account.GetByID) accounts.POST("", h.Admin.Account.Create) + accounts.POST("/check-mixed-channel", h.Admin.Account.CheckMixedChannel) accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS) accounts.POST("/sync/crs/preview", h.Admin.Account.PreviewFromCRS) accounts.PUT("/:id", h.Admin.Account.Update) @@ -219,6 +247,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.POST("/:id/clear-error", h.Admin.Account.ClearError) accounts.GET("/:id/usage", h.Admin.Account.GetUsage) accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats) + accounts.POST("/today-stats/batch", h.Admin.Account.GetBatchTodayStats) accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit) accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable) accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable) @@ -267,6 +296,19 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { } } +func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + sora := admin.Group("/sora") + { + sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL) + sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode) + sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken) + sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken) + sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken) + sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken) + sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth) + } +} + func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { 
gemini := admin.Group("/gemini") { @@ -297,6 +339,7 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) { proxies.PUT("/:id", h.Admin.Proxy.Update) proxies.DELETE("/:id", h.Admin.Proxy.Delete) proxies.POST("/:id/test", h.Admin.Proxy.Test) + proxies.POST("/:id/quality-check", h.Admin.Proxy.CheckQuality) proxies.GET("/:id/stats", h.Admin.Proxy.GetStats) proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts) proxies.POST("/batch-delete", h.Admin.Proxy.BatchDelete) @@ -311,6 +354,7 @@ func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) { codes.GET("/stats", h.Admin.Redeem.GetStats) codes.GET("/export", h.Admin.Redeem.Export) codes.GET("/:id", h.Admin.Redeem.GetByID) + codes.POST("/create-and-redeem", h.Admin.Redeem.CreateAndRedeem) codes.POST("/generate", h.Admin.Redeem.Generate) codes.DELETE("/:id", h.Admin.Redeem.Delete) codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete) @@ -344,6 +388,38 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { // 流超时处理配置 adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings) adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings) + // Sora S3 存储配置 + adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings) + adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings) + adminSettings.POST("/sora-s3/test", h.Admin.Setting.TestSoraS3Connection) + adminSettings.GET("/sora-s3/profiles", h.Admin.Setting.ListSoraS3Profiles) + adminSettings.POST("/sora-s3/profiles", h.Admin.Setting.CreateSoraS3Profile) + adminSettings.PUT("/sora-s3/profiles/:profile_id", h.Admin.Setting.UpdateSoraS3Profile) + adminSettings.DELETE("/sora-s3/profiles/:profile_id", h.Admin.Setting.DeleteSoraS3Profile) + adminSettings.POST("/sora-s3/profiles/:profile_id/activate", h.Admin.Setting.SetActiveSoraS3Profile) + } +} + +func registerDataManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + dataManagement := 
admin.Group("/data-management") + { + dataManagement.GET("/agent/health", h.Admin.DataManagement.GetAgentHealth) + dataManagement.GET("/config", h.Admin.DataManagement.GetConfig) + dataManagement.PUT("/config", h.Admin.DataManagement.UpdateConfig) + dataManagement.GET("/sources/:source_type/profiles", h.Admin.DataManagement.ListSourceProfiles) + dataManagement.POST("/sources/:source_type/profiles", h.Admin.DataManagement.CreateSourceProfile) + dataManagement.PUT("/sources/:source_type/profiles/:profile_id", h.Admin.DataManagement.UpdateSourceProfile) + dataManagement.DELETE("/sources/:source_type/profiles/:profile_id", h.Admin.DataManagement.DeleteSourceProfile) + dataManagement.POST("/sources/:source_type/profiles/:profile_id/activate", h.Admin.DataManagement.SetActiveSourceProfile) + dataManagement.POST("/s3/test", h.Admin.DataManagement.TestS3) + dataManagement.GET("/s3/profiles", h.Admin.DataManagement.ListS3Profiles) + dataManagement.POST("/s3/profiles", h.Admin.DataManagement.CreateS3Profile) + dataManagement.PUT("/s3/profiles/:profile_id", h.Admin.DataManagement.UpdateS3Profile) + dataManagement.DELETE("/s3/profiles/:profile_id", h.Admin.DataManagement.DeleteS3Profile) + dataManagement.POST("/s3/profiles/:profile_id/activate", h.Admin.DataManagement.SetActiveS3Profile) + dataManagement.POST("/backups", h.Admin.DataManagement.CreateBackupJob) + dataManagement.GET("/backups", h.Admin.DataManagement.ListBackupJobs) + dataManagement.GET("/backups/:job_id", h.Admin.DataManagement.GetBackupJob) } } diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 26d79605..c168820c 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -24,10 +24,19 @@ func RegisterAuthRoutes( // 公开接口 auth := v1.Group("/auth") { - auth.POST("/register", h.Auth.Register) - auth.POST("/login", h.Auth.Login) - auth.POST("/login/2fa", h.Auth.Login2FA) - auth.POST("/send-verify-code", h.Auth.SendVerifyCode) + 
// 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close) + auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.Register) + auth.POST("/login", rateLimiter.LimitWithOptions("auth-login", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.Login) + auth.POST("/login/2fa", rateLimiter.LimitWithOptions("auth-login-2fa", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.Login2FA) + auth.POST("/send-verify-code", rateLimiter.LimitWithOptions("auth-send-verify-code", 5, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.SendVerifyCode) // Token刷新接口添加速率限制:每分钟最多 30 次(Redis 故障时 fail-close) auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{ FailureMode: middleware.RateLimitFailClose, diff --git a/backend/internal/server/routes/auth_rate_limit_integration_test.go b/backend/internal/server/routes/auth_rate_limit_integration_test.go new file mode 100644 index 00000000..8a0ef860 --- /dev/null +++ b/backend/internal/server/routes/auth_rate_limit_integration_test.go @@ -0,0 +1,111 @@ +//go:build integration + +package routes + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strconv" + "strings" + "testing" + + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" + tcredis "github.com/testcontainers/testcontainers-go/modules/redis" +) + +const authRouteRedisImageTag = "redis:8.4-alpine" + +func TestAuthRegisterRateLimitThresholdHitReturns429(t *testing.T) { + ctx := context.Background() + rdb := startAuthRouteRedis(t, ctx) + + router := newAuthRoutesTestRouter(rdb) + const path = "/api/v1/auth/register" + + for i := 1; i <= 6; i++ { + req := httptest.NewRequest(http.MethodPost, path, 
strings.NewReader(`{}`)) + req.Header.Set("Content-Type", "application/json") + req.RemoteAddr = "198.51.100.10:23456" + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if i <= 5 { + require.Equal(t, http.StatusBadRequest, w.Code, "第 %d 次请求应先进入业务校验", i) + continue + } + require.Equal(t, http.StatusTooManyRequests, w.Code, "第 6 次请求应命中限流") + require.Contains(t, w.Body.String(), "rate limit exceeded") + } +} + +func startAuthRouteRedis(t *testing.T, ctx context.Context) *redis.Client { + t.Helper() + ensureAuthRouteDockerAvailable(t) + + redisContainer, err := tcredis.Run(ctx, authRouteRedisImageTag) + require.NoError(t, err) + t.Cleanup(func() { + _ = redisContainer.Terminate(ctx) + }) + + redisHost, err := redisContainer.Host(ctx) + require.NoError(t, err) + redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp") + require.NoError(t, err) + + rdb := redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()), + DB: 0, + }) + require.NoError(t, rdb.Ping(ctx).Err()) + t.Cleanup(func() { + _ = rdb.Close() + }) + return rdb +} + +func ensureAuthRouteDockerAvailable(t *testing.T) { + t.Helper() + if authRouteDockerAvailable() { + return + } + t.Skip("Docker 未启用,跳过认证限流集成测试") +} + +func authRouteDockerAvailable() bool { + if os.Getenv("DOCKER_HOST") != "" { + return true + } + + socketCandidates := []string{ + "/var/run/docker.sock", + filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"), + filepath.Join(authRouteUserHomeDir(), ".docker", "run", "docker.sock"), + filepath.Join(authRouteUserHomeDir(), ".docker", "desktop", "docker.sock"), + filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"), + } + + for _, socket := range socketCandidates { + if socket == "" { + continue + } + if _, err := os.Stat(socket); err == nil { + return true + } + } + return false +} + +func authRouteUserHomeDir() string { + home, err := os.UserHomeDir() + if err != nil { + return "" + } + return home +} diff --git 
a/backend/internal/server/routes/auth_rate_limit_test.go b/backend/internal/server/routes/auth_rate_limit_test.go new file mode 100644 index 00000000..5ce8497c --- /dev/null +++ b/backend/internal/server/routes/auth_rate_limit_test.go @@ -0,0 +1,67 @@ +package routes + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/handler" + servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/gin-gonic/gin" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + v1 := router.Group("/api/v1") + + RegisterAuthRoutes( + v1, + &handler.Handlers{ + Auth: &handler.AuthHandler{}, + Setting: &handler.SettingHandler{}, + }, + servermiddleware.JWTAuthMiddleware(func(c *gin.Context) { + c.Next() + }), + redisClient, + ) + + return router +} + +func TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable(t *testing.T) { + rdb := redis.NewClient(&redis.Options{ + Addr: "127.0.0.1:1", + DialTimeout: 50 * time.Millisecond, + ReadTimeout: 50 * time.Millisecond, + WriteTimeout: 50 * time.Millisecond, + }) + t.Cleanup(func() { + _ = rdb.Close() + }) + + router := newAuthRoutesTestRouter(rdb) + paths := []string{ + "/api/v1/auth/register", + "/api/v1/auth/login", + "/api/v1/auth/login/2fa", + "/api/v1/auth/send-verify-code", + } + + for _, path := range paths { + req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`)) + req.Header.Set("Content-Type", "application/json") + req.RemoteAddr = "203.0.113.10:12345" + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusTooManyRequests, w.Code, "path=%s", path) + require.Contains(t, w.Body.String(), "rate limit exceeded", "path=%s", path) + } +} diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index 
bf019ce3..13f13320 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -1,6 +1,8 @@ package routes import ( + "net/http" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -17,18 +19,29 @@ func RegisterGatewayRoutes( apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, opsService *service.OpsService, + settingService *service.SettingService, cfg *config.Config, ) { bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize) + soraMaxBodySize := cfg.Gateway.SoraMaxBodySize + if soraMaxBodySize <= 0 { + soraMaxBodySize = cfg.Gateway.MaxBodySize + } + soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize) clientRequestID := middleware.ClientRequestID() opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService) + // 未分组 Key 拦截中间件(按协议格式区分错误响应) + requireGroupAnthropic := middleware.RequireGroupAssignment(settingService, middleware.AnthropicErrorWriter) + requireGroupGoogle := middleware.RequireGroupAssignment(settingService, middleware.GoogleErrorWriter) + // API网关(Claude API兼容) gateway := r.Group("/v1") gateway.Use(bodyLimit) gateway.Use(clientRequestID) gateway.Use(opsErrorLogger) gateway.Use(gin.HandlerFunc(apiKeyAuth)) + gateway.Use(requireGroupAnthropic) { gateway.POST("/messages", h.Gateway.Messages) gateway.POST("/messages/count_tokens", h.Gateway.CountTokens) @@ -36,6 +49,16 @@ func RegisterGatewayRoutes( gateway.GET("/usage", h.Gateway.Usage) // OpenAI Responses API gateway.POST("/responses", h.OpenAIGateway.Responses) + gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket) + // 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。 + gateway.POST("/chat/completions", func(c *gin.Context) { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": "Unsupported legacy protocol: /v1/chat/completions is 
not supported. Please use /v1/responses.", + }, + }) + }) } // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) @@ -44,6 +67,7 @@ func RegisterGatewayRoutes( gemini.Use(clientRequestID) gemini.Use(opsErrorLogger) gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) + gemini.Use(requireGroupGoogle) { gemini.GET("/models", h.Gateway.GeminiV1BetaListModels) gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) @@ -52,10 +76,11 @@ func RegisterGatewayRoutes( } // OpenAI Responses API(不带v1前缀的别名) - r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses) + r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses) + r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket) // Antigravity 模型列表 - r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels) + r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels) // Antigravity 专用路由(仅使用 antigravity 账户,不混合调度) antigravityV1 := r.Group("/antigravity/v1") @@ -64,6 +89,7 @@ func RegisterGatewayRoutes( antigravityV1.Use(opsErrorLogger) antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity)) antigravityV1.Use(gin.HandlerFunc(apiKeyAuth)) + antigravityV1.Use(requireGroupAnthropic) { antigravityV1.POST("/messages", h.Gateway.Messages) antigravityV1.POST("/messages/count_tokens", h.Gateway.CountTokens) @@ -77,9 +103,32 @@ func RegisterGatewayRoutes( antigravityV1Beta.Use(opsErrorLogger) antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity)) antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) + antigravityV1Beta.Use(requireGroupGoogle) { antigravityV1Beta.GET("/models", 
h.Gateway.GeminiV1BetaListModels) antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels) } + + // Sora 专用路由(强制使用 sora 平台) + soraV1 := r.Group("/sora/v1") + soraV1.Use(soraBodyLimit) + soraV1.Use(clientRequestID) + soraV1.Use(opsErrorLogger) + soraV1.Use(middleware.ForcePlatform(service.PlatformSora)) + soraV1.Use(gin.HandlerFunc(apiKeyAuth)) + soraV1.Use(requireGroupAnthropic) + { + soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions) + soraV1.GET("/models", h.Gateway.Models) + } + + // Sora 媒体代理(可选 API Key 验证) + if cfg.Gateway.SoraMediaRequireAPIKey { + r.GET("/sora/media/*filepath", gin.HandlerFunc(apiKeyAuth), h.SoraGateway.MediaProxy) + } else { + r.GET("/sora/media/*filepath", h.SoraGateway.MediaProxy) + } + // Sora 媒体代理(签名 URL,无需 API Key) + r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned) } diff --git a/backend/internal/server/routes/sora_client.go b/backend/internal/server/routes/sora_client.go new file mode 100644 index 00000000..40ae0436 --- /dev/null +++ b/backend/internal/server/routes/sora_client.go @@ -0,0 +1,33 @@ +package routes + +import ( + "github.com/Wei-Shaw/sub2api/internal/handler" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + + "github.com/gin-gonic/gin" +) + +// RegisterSoraClientRoutes 注册 Sora 客户端 API 路由(需要用户认证)。 +func RegisterSoraClientRoutes( + v1 *gin.RouterGroup, + h *handler.Handlers, + jwtAuth middleware.JWTAuthMiddleware, +) { + if h.SoraClient == nil { + return + } + + authenticated := v1.Group("/sora") + authenticated.Use(gin.HandlerFunc(jwtAuth)) + { + authenticated.POST("/generate", h.SoraClient.Generate) + authenticated.GET("/generations", h.SoraClient.ListGenerations) + authenticated.GET("/generations/:id", h.SoraClient.GetGeneration) + authenticated.DELETE("/generations/:id", h.SoraClient.DeleteGeneration) + authenticated.POST("/generations/:id/cancel", h.SoraClient.CancelGeneration) + 
authenticated.POST("/generations/:id/save", h.SoraClient.SaveToStorage) + authenticated.GET("/quota", h.SoraClient.GetQuota) + authenticated.GET("/models", h.SoraClient.GetModels) + authenticated.GET("/storage-status", h.SoraClient.GetStorageStatus) + } +} diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 138d5bcb..81e91aeb 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -3,11 +3,14 @@ package service import ( "encoding/json" + "hash/fnv" + "reflect" "sort" "strconv" "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/domain" ) @@ -50,6 +53,14 @@ type Account struct { AccountGroups []AccountGroup GroupIDs []int64 Groups []*Group + + // model_mapping 热路径缓存(非持久化字段) + modelMappingCache map[string]string + modelMappingCacheReady bool + modelMappingCacheCredentialsPtr uintptr + modelMappingCacheRawPtr uintptr + modelMappingCacheRawLen int + modelMappingCacheRawSig uint64 } type TempUnschedulableRule struct { @@ -349,6 +360,39 @@ func parseTempUnschedInt(value any) int { } func (a *Account) GetModelMapping() map[string]string { + credentialsPtr := mapPtr(a.Credentials) + rawMapping, _ := a.Credentials["model_mapping"].(map[string]any) + rawPtr := mapPtr(rawMapping) + rawLen := len(rawMapping) + rawSig := uint64(0) + rawSigReady := false + + if a.modelMappingCacheReady && + a.modelMappingCacheCredentialsPtr == credentialsPtr && + a.modelMappingCacheRawPtr == rawPtr && + a.modelMappingCacheRawLen == rawLen { + rawSig = modelMappingSignature(rawMapping) + rawSigReady = true + if a.modelMappingCacheRawSig == rawSig { + return a.modelMappingCache + } + } + + mapping := a.resolveModelMapping(rawMapping) + if !rawSigReady { + rawSig = modelMappingSignature(rawMapping) + } + + a.modelMappingCache = mapping + a.modelMappingCacheReady = true + a.modelMappingCacheCredentialsPtr = credentialsPtr + a.modelMappingCacheRawPtr = rawPtr + 
a.modelMappingCacheRawLen = rawLen + a.modelMappingCacheRawSig = rawSig + return mapping +} + +func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]string { if a.Credentials == nil { // Antigravity 平台使用默认映射 if a.Platform == domain.PlatformAntigravity { @@ -356,25 +400,31 @@ func (a *Account) GetModelMapping() map[string]string { } return nil } - raw, ok := a.Credentials["model_mapping"] - if !ok || raw == nil { + if len(rawMapping) == 0 { // Antigravity 平台使用默认映射 if a.Platform == domain.PlatformAntigravity { return domain.DefaultAntigravityModelMapping } return nil } - if m, ok := raw.(map[string]any); ok { - result := make(map[string]string) - for k, v := range m { - if s, ok := v.(string); ok { - result[k] = s - } - } - if len(result) > 0 { - return result + + result := make(map[string]string) + for k, v := range rawMapping { + if s, ok := v.(string); ok { + result[k] = s } } + if len(result) > 0 { + if a.Platform == domain.PlatformAntigravity { + ensureAntigravityDefaultPassthroughs(result, []string{ + "gemini-3-flash", + "gemini-3.1-pro-high", + "gemini-3.1-pro-low", + }) + } + return result + } + // Antigravity 平台使用默认映射 if a.Platform == domain.PlatformAntigravity { return domain.DefaultAntigravityModelMapping @@ -382,6 +432,58 @@ func (a *Account) GetModelMapping() map[string]string { return nil } +func mapPtr(m map[string]any) uintptr { + if m == nil { + return 0 + } + return reflect.ValueOf(m).Pointer() +} + +func modelMappingSignature(rawMapping map[string]any) uint64 { + if len(rawMapping) == 0 { + return 0 + } + keys := make([]string, 0, len(rawMapping)) + for k := range rawMapping { + keys = append(keys, k) + } + sort.Strings(keys) + + h := fnv.New64a() + for _, k := range keys { + _, _ = h.Write([]byte(k)) + _, _ = h.Write([]byte{0}) + if v, ok := rawMapping[k].(string); ok { + _, _ = h.Write([]byte(v)) + } else { + _, _ = h.Write([]byte{1}) + } + _, _ = h.Write([]byte{0xff}) + } + return h.Sum64() +} + +func 
ensureAntigravityDefaultPassthrough(mapping map[string]string, model string) { + if mapping == nil || model == "" { + return + } + if _, exists := mapping[model]; exists { + return + } + for pattern := range mapping { + if matchWildcard(pattern, model) { + return + } + } + mapping[model] = model +} + +func ensureAntigravityDefaultPassthroughs(mapping map[string]string, models []string) { + for _, model := range models { + ensureAntigravityDefaultPassthrough(mapping, model) + } +} + // IsModelSupported 检查模型是否在 model_mapping 中(支持通配符) // 如果未配置 mapping,返回 true(允许所有模型) func (a *Account) IsModelSupported(requestedModel string) bool { @@ -696,6 +798,204 @@ func (a *Account) IsMixedSchedulingEnabled() bool { return false } +// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用“自动透传(仅替换认证)”。 +// +// 新字段:accounts.extra.openai_passthrough。 +// 兼容字段:accounts.extra.openai_oauth_passthrough(历史 OAuth 开关)。 +// 字段缺失或类型不正确时,按 false(关闭)处理。 +func (a *Account) IsOpenAIPassthroughEnabled() bool { + if a == nil || !a.IsOpenAI() || a.Extra == nil { + return false + } + if enabled, ok := a.Extra["openai_passthrough"].(bool); ok { + return enabled + } + if enabled, ok := a.Extra["openai_oauth_passthrough"].(bool); ok { + return enabled + } + return false +} + +// IsOpenAIResponsesWebSocketV2Enabled 返回 OpenAI 账号是否开启 Responses WebSocket v2。 +// +// 分类型新字段: +// - OAuth 账号:accounts.extra.openai_oauth_responses_websockets_v2_enabled +// - API Key 账号:accounts.extra.openai_apikey_responses_websockets_v2_enabled +// +// 兼容字段: +// - accounts.extra.responses_websockets_v2_enabled +// - accounts.extra.openai_ws_enabled(历史开关) +// +// 优先级: +// 1. 按账号类型读取分类型字段 +// 2. 
分类型字段缺失时,回退兼容字段 +func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool { + if a == nil || !a.IsOpenAI() || a.Extra == nil { + return false + } + if a.IsOpenAIOAuth() { + if enabled, ok := a.Extra["openai_oauth_responses_websockets_v2_enabled"].(bool); ok { + return enabled + } + } + if a.IsOpenAIApiKey() { + if enabled, ok := a.Extra["openai_apikey_responses_websockets_v2_enabled"].(bool); ok { + return enabled + } + } + if enabled, ok := a.Extra["responses_websockets_v2_enabled"].(bool); ok { + return enabled + } + if enabled, ok := a.Extra["openai_ws_enabled"].(bool); ok { + return enabled + } + return false +} + +const ( + OpenAIWSIngressModeOff = "off" + OpenAIWSIngressModeShared = "shared" + OpenAIWSIngressModeDedicated = "dedicated" +) + +func normalizeOpenAIWSIngressMode(mode string) string { + switch strings.ToLower(strings.TrimSpace(mode)) { + case OpenAIWSIngressModeOff: + return OpenAIWSIngressModeOff + case OpenAIWSIngressModeShared: + return OpenAIWSIngressModeShared + case OpenAIWSIngressModeDedicated: + return OpenAIWSIngressModeDedicated + default: + return "" + } +} + +func normalizeOpenAIWSIngressDefaultMode(mode string) string { + if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" { + return normalized + } + return OpenAIWSIngressModeShared +} + +// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/shared/dedicated)。 +// +// 优先级: +// 1. 分类型 mode 新字段(string) +// 2. 分类型 enabled 旧字段(bool) +// 3. 兼容 enabled 旧字段(bool) +// 4. 
defaultMode(非法时回退 shared) +func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string { + resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode) + if a == nil || !a.IsOpenAI() { + return OpenAIWSIngressModeOff + } + if a.Extra == nil { + return resolvedDefault + } + + resolveModeString := func(key string) (string, bool) { + raw, ok := a.Extra[key] + if !ok { + return "", false + } + mode, ok := raw.(string) + if !ok { + return "", false + } + normalized := normalizeOpenAIWSIngressMode(mode) + if normalized == "" { + return "", false + } + return normalized, true + } + resolveBoolMode := func(key string) (string, bool) { + raw, ok := a.Extra[key] + if !ok { + return "", false + } + enabled, ok := raw.(bool) + if !ok { + return "", false + } + if enabled { + return OpenAIWSIngressModeShared, true + } + return OpenAIWSIngressModeOff, true + } + + if a.IsOpenAIOAuth() { + if mode, ok := resolveModeString("openai_oauth_responses_websockets_v2_mode"); ok { + return mode + } + if mode, ok := resolveBoolMode("openai_oauth_responses_websockets_v2_enabled"); ok { + return mode + } + } + if a.IsOpenAIApiKey() { + if mode, ok := resolveModeString("openai_apikey_responses_websockets_v2_mode"); ok { + return mode + } + if mode, ok := resolveBoolMode("openai_apikey_responses_websockets_v2_enabled"); ok { + return mode + } + } + if mode, ok := resolveBoolMode("responses_websockets_v2_enabled"); ok { + return mode + } + if mode, ok := resolveBoolMode("openai_ws_enabled"); ok { + return mode + } + return resolvedDefault +} + +// IsOpenAIWSForceHTTPEnabled 返回账号级“强制 HTTP”开关。 +// 字段:accounts.extra.openai_ws_force_http。 +func (a *Account) IsOpenAIWSForceHTTPEnabled() bool { + if a == nil || !a.IsOpenAI() || a.Extra == nil { + return false + } + enabled, ok := a.Extra["openai_ws_force_http"].(bool) + return ok && enabled +} + +// IsOpenAIWSAllowStoreRecoveryEnabled 返回账号级 store 恢复开关。 +// 字段:accounts.extra.openai_ws_allow_store_recovery。 +func (a 
*Account) IsOpenAIWSAllowStoreRecoveryEnabled() bool { + if a == nil || !a.IsOpenAI() || a.Extra == nil { + return false + } + enabled, ok := a.Extra["openai_ws_allow_store_recovery"].(bool) + return ok && enabled +} + +// IsOpenAIOAuthPassthroughEnabled 兼容旧接口,等价于 OAuth 账号的 IsOpenAIPassthroughEnabled。 +func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool { + return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled() +} + +// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用“自动透传(仅替换认证)”。 +// 字段:accounts.extra.anthropic_passthrough。 +// 字段缺失或类型不正确时,按 false(关闭)处理。 +func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool { + if a == nil || a.Platform != PlatformAnthropic || a.Type != AccountTypeAPIKey || a.Extra == nil { + return false + } + enabled, ok := a.Extra["anthropic_passthrough"].(bool) + return ok && enabled +} + +// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用“仅允许 Codex 官方客户端”。 +// 字段:accounts.extra.codex_cli_only。 +// 字段缺失或类型不正确时,按 false(关闭)处理。 +func (a *Account) IsCodexCLIOnlyEnabled() bool { + if a == nil || !a.IsOpenAIOAuth() || a.Extra == nil { + return false + } + enabled, ok := a.Extra["codex_cli_only"].(bool) + return ok && enabled +} + // WindowCostSchedulability 窗口费用调度状态 type WindowCostSchedulability int @@ -733,6 +1033,26 @@ func (a *Account) IsTLSFingerprintEnabled() bool { return false } +// GetUserMsgQueueMode 获取用户消息队列模式 +// "serialize" = 串行队列, "throttle" = 软性限速, "" = 未设置(使用全局配置) +func (a *Account) GetUserMsgQueueMode() string { + if a.Extra == nil { + return "" + } + // 优先读取新字段 user_msg_queue_mode(白名单校验,非法值视为未设置) + if mode, ok := a.Extra["user_msg_queue_mode"].(string); ok && mode != "" { + if mode == config.UMQModeSerialize || mode == config.UMQModeThrottle { + return mode + } + return "" // 非法值 fallback 到全局配置 + } + // 向后兼容: user_msg_queue_enabled: true → "serialize" + if enabled, ok := a.Extra["user_msg_queue_enabled"].(bool); ok && enabled { + return config.UMQModeSerialize + } + return "" +} + // 
IsSessionIDMaskingEnabled 检查是否启用会话ID伪装 // 仅适用于 Anthropic OAuth/SetupToken 类型账号 // 启用后将在一段时间内(15分钟)固定 metadata.user_id 中的 session ID, @@ -752,6 +1072,38 @@ func (a *Account) IsSessionIDMaskingEnabled() bool { return false } +// IsCacheTTLOverrideEnabled 检查是否启用缓存 TTL 强制替换 +// 仅适用于 Anthropic OAuth/SetupToken 类型账号 +// 启用后将所有 cache creation tokens 归入指定的 TTL 类型(5m 或 1h) +func (a *Account) IsCacheTTLOverrideEnabled() bool { + if !a.IsAnthropicOAuthOrSetupToken() { + return false + } + if a.Extra == nil { + return false + } + if v, ok := a.Extra["cache_ttl_override_enabled"]; ok { + if enabled, ok := v.(bool); ok { + return enabled + } + } + return false +} + +// GetCacheTTLOverrideTarget 获取缓存 TTL 强制替换的目标类型 +// 返回 "5m" 或 "1h",默认 "5m" +func (a *Account) GetCacheTTLOverrideTarget() string { + if a.Extra == nil { + return "5m" + } + if v, ok := a.Extra["cache_ttl_override_target"]; ok { + if target, ok := v.(string); ok && (target == "5m" || target == "1h") { + return target + } + } + return "5m" +} + // GetWindowCostLimit 获取 5h 窗口费用阈值(美元) // 返回 0 表示未启用 func (a *Account) GetWindowCostLimit() float64 { @@ -806,6 +1158,80 @@ func (a *Account) GetSessionIdleTimeoutMinutes() int { return 5 } +// GetBaseRPM 获取基础 RPM 限制 +// 返回 0 表示未启用(负数视为无效配置,按 0 处理) +func (a *Account) GetBaseRPM() int { + if a.Extra == nil { + return 0 + } + if v, ok := a.Extra["base_rpm"]; ok { + val := parseExtraInt(v) + if val > 0 { + return val + } + } + return 0 +} + +// GetRPMStrategy 获取 RPM 策略 +// "tiered" = 三区模型(默认), "sticky_exempt" = 粘性豁免 +func (a *Account) GetRPMStrategy() string { + if a.Extra == nil { + return "tiered" + } + if v, ok := a.Extra["rpm_strategy"]; ok { + if s, ok := v.(string); ok && s == "sticky_exempt" { + return "sticky_exempt" + } + } + return "tiered" +} + +// GetRPMStickyBuffer 获取 RPM 粘性缓冲数量 +// tiered 模式下的黄区大小,默认为 base_rpm 的 20%(至少 1) +func (a *Account) GetRPMStickyBuffer() int { + if a.Extra == nil { + return 0 + } + if v, ok := a.Extra["rpm_sticky_buffer"]; ok { + val := 
parseExtraInt(v) + if val > 0 { + return val + } + } + base := a.GetBaseRPM() + buffer := base / 5 + if buffer < 1 && base > 0 { + buffer = 1 + } + return buffer +} + +// CheckRPMSchedulability 根据当前 RPM 计数检查调度状态 +// 复用 WindowCostSchedulability 三态:Schedulable / StickyOnly / NotSchedulable +func (a *Account) CheckRPMSchedulability(currentRPM int) WindowCostSchedulability { + baseRPM := a.GetBaseRPM() + if baseRPM <= 0 { + return WindowCostSchedulable + } + + if currentRPM < baseRPM { + return WindowCostSchedulable + } + + strategy := a.GetRPMStrategy() + if strategy == "sticky_exempt" { + return WindowCostStickyOnly // 粘性豁免无红区 + } + + // tiered: 黄区 + 红区 + buffer := a.GetRPMStickyBuffer() + if currentRPM < baseRPM+buffer { + return WindowCostStickyOnly + } + return WindowCostNotSchedulable +} + // CheckWindowCostSchedulability 根据当前窗口费用检查调度状态 // - 费用 < 阈值: WindowCostSchedulable(可正常调度) // - 费用 >= 阈值 且 < 阈值+预留: WindowCostStickyOnly(仅粘性会话) @@ -869,6 +1295,12 @@ func parseExtraFloat64(value any) float64 { } // parseExtraInt 从 extra 字段解析 int 值 +// ParseExtraInt 从 extra 字段的 any 值解析为 int。 +// 支持 int, int64, float64, json.Number, string 类型,无法解析时返回 0。 +func ParseExtraInt(value any) int { + return parseExtraInt(value) +} + func parseExtraInt(value any) int { switch v := value.(type) { case int: diff --git a/backend/internal/service/account_anthropic_passthrough_test.go b/backend/internal/service/account_anthropic_passthrough_test.go new file mode 100644 index 00000000..e66407a3 --- /dev/null +++ b/backend/internal/service/account_anthropic_passthrough_test.go @@ -0,0 +1,62 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAccount_IsAnthropicAPIKeyPassthroughEnabled(t *testing.T) { + t.Run("Anthropic API Key 开启", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "anthropic_passthrough": true, + }, + } + require.True(t, 
account.IsAnthropicAPIKeyPassthroughEnabled()) + }) + + t.Run("Anthropic API Key 关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "anthropic_passthrough": false, + }, + } + require.False(t, account.IsAnthropicAPIKeyPassthroughEnabled()) + }) + + t.Run("字段类型非法默认关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "anthropic_passthrough": "true", + }, + } + require.False(t, account.IsAnthropicAPIKeyPassthroughEnabled()) + }) + + t.Run("非 Anthropic API Key 账号始终关闭", func(t *testing.T) { + oauth := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "anthropic_passthrough": true, + }, + } + require.False(t, oauth.IsAnthropicAPIKeyPassthroughEnabled()) + + openai := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "anthropic_passthrough": true, + }, + } + require.False(t, openai.IsAnthropicAPIKeyPassthroughEnabled()) + }) +} diff --git a/backend/internal/service/account_intercept_warmup_test.go b/backend/internal/service/account_intercept_warmup_test.go new file mode 100644 index 00000000..f117fd8d --- /dev/null +++ b/backend/internal/service/account_intercept_warmup_test.go @@ -0,0 +1,66 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAccount_IsInterceptWarmupEnabled(t *testing.T) { + tests := []struct { + name string + credentials map[string]any + expected bool + }{ + { + name: "nil credentials", + credentials: nil, + expected: false, + }, + { + name: "empty map", + credentials: map[string]any{}, + expected: false, + }, + { + name: "field not present", + credentials: map[string]any{"access_token": "tok"}, + expected: false, + }, + { + name: "field is true", + credentials: map[string]any{"intercept_warmup_requests": true}, + expected: true, + }, + { + 
name: "field is false", + credentials: map[string]any{"intercept_warmup_requests": false}, + expected: false, + }, + { + name: "field is string true", + credentials: map[string]any{"intercept_warmup_requests": "true"}, + expected: false, + }, + { + name: "field is int 1", + credentials: map[string]any{"intercept_warmup_requests": 1}, + expected: false, + }, + { + name: "field is nil", + credentials: map[string]any{"intercept_warmup_requests": nil}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Account{Credentials: tt.credentials} + result := a.IsInterceptWarmupEnabled() + require.Equal(t, tt.expected, result) + }) + } +} diff --git a/backend/internal/service/account_openai_passthrough_test.go b/backend/internal/service/account_openai_passthrough_test.go new file mode 100644 index 00000000..a85c68ec --- /dev/null +++ b/backend/internal/service/account_openai_passthrough_test.go @@ -0,0 +1,294 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAccount_IsOpenAIPassthroughEnabled(t *testing.T) { + t.Run("新字段开启", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "openai_passthrough": true, + }, + } + require.True(t, account.IsOpenAIPassthroughEnabled()) + }) + + t.Run("兼容旧字段", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_passthrough": true, + }, + } + require.True(t, account.IsOpenAIPassthroughEnabled()) + }) + + t.Run("非OpenAI账号始终关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_passthrough": true, + }, + } + require.False(t, account.IsOpenAIPassthroughEnabled()) + }) + + t.Run("空额外配置默认关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + 
require.False(t, account.IsOpenAIPassthroughEnabled()) + }) +} + +func TestAccount_IsOpenAIOAuthPassthroughEnabled(t *testing.T) { + t.Run("仅OAuth类型允许返回开启", func(t *testing.T) { + oauthAccount := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_passthrough": true, + }, + } + require.True(t, oauthAccount.IsOpenAIOAuthPassthroughEnabled()) + + apiKeyAccount := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "openai_passthrough": true, + }, + } + require.False(t, apiKeyAccount.IsOpenAIOAuthPassthroughEnabled()) + }) +} + +func TestAccount_IsCodexCLIOnlyEnabled(t *testing.T) { + t.Run("OpenAI OAuth 开启", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "codex_cli_only": true, + }, + } + require.True(t, account.IsCodexCLIOnlyEnabled()) + }) + + t.Run("OpenAI OAuth 关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "codex_cli_only": false, + }, + } + require.False(t, account.IsCodexCLIOnlyEnabled()) + }) + + t.Run("字段缺失默认关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{}, + } + require.False(t, account.IsCodexCLIOnlyEnabled()) + }) + + t.Run("类型非法默认关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "codex_cli_only": "true", + }, + } + require.False(t, account.IsCodexCLIOnlyEnabled()) + }) + + t.Run("非 OAuth 账号始终关闭", func(t *testing.T) { + apiKeyAccount := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "codex_cli_only": true, + }, + } + require.False(t, apiKeyAccount.IsCodexCLIOnlyEnabled()) + + otherPlatform := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "codex_cli_only": true, + 
}, + } + require.False(t, otherPlatform.IsCodexCLIOnlyEnabled()) + }) +} + +func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) { + t.Run("OAuth使用OAuth专用开关", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": true, + }, + } + require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled()) + }) + + t.Run("API Key使用API Key专用开关", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled()) + }) + + t.Run("OAuth账号不会读取API Key专用开关", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled()) + }) + + t.Run("分类型新键优先于兼容键", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": false, + "responses_websockets_v2_enabled": true, + "openai_ws_enabled": true, + }, + } + require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled()) + }) + + t.Run("分类型键缺失时回退兼容键", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled()) + }) + + t.Run("非OpenAI账号默认关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled()) + }) +} + +func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t 
*testing.T) { + t.Run("default fallback to shared", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{}, + } + require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("")) + require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid")) + }) + + t.Run("oauth mode field has highest priority", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, + "openai_oauth_responses_websockets_v2_enabled": false, + "responses_websockets_v2_enabled": false, + }, + } + require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared)) + }) + + t.Run("legacy enabled maps to shared", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff)) + }) + + t.Run("legacy disabled maps to off", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": false, + "responses_websockets_v2_enabled": true, + }, + } + require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared)) + }) + + t.Run("non openai always off", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, + }, + } + require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeDedicated)) + }) +} + +func 
TestAccount_OpenAIWSExtraFlags(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_ws_force_http": true, + "openai_ws_allow_store_recovery": true, + }, + } + require.True(t, account.IsOpenAIWSForceHTTPEnabled()) + require.True(t, account.IsOpenAIWSAllowStoreRecoveryEnabled()) + + off := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{}} + require.False(t, off.IsOpenAIWSForceHTTPEnabled()) + require.False(t, off.IsOpenAIWSAllowStoreRecoveryEnabled()) + + var nilAccount *Account + require.False(t, nilAccount.IsOpenAIWSAllowStoreRecoveryEnabled()) + + nonOpenAI := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_ws_allow_store_recovery": true, + }, + } + require.False(t, nonOpenAI.IsOpenAIWSAllowStoreRecoveryEnabled()) +} diff --git a/backend/internal/service/account_rpm_test.go b/backend/internal/service/account_rpm_test.go new file mode 100644 index 00000000..9d91f3e0 --- /dev/null +++ b/backend/internal/service/account_rpm_test.go @@ -0,0 +1,120 @@ +package service + +import ( + "encoding/json" + "testing" +) + +func TestGetBaseRPM(t *testing.T) { + tests := []struct { + name string + extra map[string]any + expected int + }{ + {"nil extra", nil, 0}, + {"no key", map[string]any{}, 0}, + {"zero", map[string]any{"base_rpm": 0}, 0}, + {"int value", map[string]any{"base_rpm": 15}, 15}, + {"float value", map[string]any{"base_rpm": 15.0}, 15}, + {"string value", map[string]any{"base_rpm": "15"}, 15}, + {"negative value", map[string]any{"base_rpm": -5}, 0}, + {"int64 value", map[string]any{"base_rpm": int64(20)}, 20}, + {"json.Number value", map[string]any{"base_rpm": json.Number("25")}, 25}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Account{Extra: tt.extra} + if got := a.GetBaseRPM(); got != tt.expected { + t.Errorf("GetBaseRPM() = %d, want %d", got, tt.expected) + } + }) + } 
+} + +func TestGetRPMStrategy(t *testing.T) { + tests := []struct { + name string + extra map[string]any + expected string + }{ + {"nil extra", nil, "tiered"}, + {"no key", map[string]any{}, "tiered"}, + {"tiered", map[string]any{"rpm_strategy": "tiered"}, "tiered"}, + {"sticky_exempt", map[string]any{"rpm_strategy": "sticky_exempt"}, "sticky_exempt"}, + {"invalid", map[string]any{"rpm_strategy": "foobar"}, "tiered"}, + {"empty string fallback", map[string]any{"rpm_strategy": ""}, "tiered"}, + {"numeric value fallback", map[string]any{"rpm_strategy": 123}, "tiered"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Account{Extra: tt.extra} + if got := a.GetRPMStrategy(); got != tt.expected { + t.Errorf("GetRPMStrategy() = %q, want %q", got, tt.expected) + } + }) + } +} + +func TestCheckRPMSchedulability(t *testing.T) { + tests := []struct { + name string + extra map[string]any + currentRPM int + expected WindowCostSchedulability + }{ + {"disabled", map[string]any{}, 100, WindowCostSchedulable}, + {"green zone", map[string]any{"base_rpm": 15}, 10, WindowCostSchedulable}, + {"yellow zone tiered", map[string]any{"base_rpm": 15}, 15, WindowCostStickyOnly}, + {"red zone tiered", map[string]any{"base_rpm": 15}, 18, WindowCostNotSchedulable}, + {"sticky_exempt at limit", map[string]any{"base_rpm": 15, "rpm_strategy": "sticky_exempt"}, 15, WindowCostStickyOnly}, + {"sticky_exempt over limit", map[string]any{"base_rpm": 15, "rpm_strategy": "sticky_exempt"}, 100, WindowCostStickyOnly}, + {"custom buffer", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 14, WindowCostStickyOnly}, + {"custom buffer red", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 15, WindowCostNotSchedulable}, + {"base_rpm=1 green", map[string]any{"base_rpm": 1}, 0, WindowCostSchedulable}, + {"base_rpm=1 yellow (at limit)", map[string]any{"base_rpm": 1}, 1, WindowCostStickyOnly}, + {"base_rpm=1 red (at limit+buffer)", map[string]any{"base_rpm": 1}, 2, 
WindowCostNotSchedulable}, + {"negative currentRPM", map[string]any{"base_rpm": 15}, -1, WindowCostSchedulable}, + {"base_rpm negative disabled", map[string]any{"base_rpm": -5}, 10, WindowCostSchedulable}, + {"very high currentRPM", map[string]any{"base_rpm": 10}, 9999, WindowCostNotSchedulable}, + {"sticky_exempt very high currentRPM", map[string]any{"base_rpm": 10, "rpm_strategy": "sticky_exempt"}, 9999, WindowCostStickyOnly}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Account{Extra: tt.extra} + if got := a.CheckRPMSchedulability(tt.currentRPM); got != tt.expected { + t.Errorf("CheckRPMSchedulability(%d) = %d, want %d", tt.currentRPM, got, tt.expected) + } + }) + } +} + +func TestGetRPMStickyBuffer(t *testing.T) { + tests := []struct { + name string + extra map[string]any + expected int + }{ + {"nil extra", nil, 0}, + {"no keys", map[string]any{}, 0}, + {"base_rpm=0", map[string]any{"base_rpm": 0}, 0}, + {"base_rpm=1 min buffer 1", map[string]any{"base_rpm": 1}, 1}, + {"base_rpm=4 min buffer 1", map[string]any{"base_rpm": 4}, 1}, + {"base_rpm=5 buffer 1", map[string]any{"base_rpm": 5}, 1}, + {"base_rpm=10 buffer 2", map[string]any{"base_rpm": 10}, 2}, + {"base_rpm=15 buffer 3", map[string]any{"base_rpm": 15}, 3}, + {"base_rpm=100 buffer 20", map[string]any{"base_rpm": 100}, 20}, + {"custom buffer=5", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 5}, + {"custom buffer=0 fallback to default", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 0}, 2}, + {"custom buffer negative fallback", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": -1}, 2}, + {"custom buffer with float", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": float64(7)}, 7}, + {"json.Number base_rpm", map[string]any{"base_rpm": json.Number("10")}, 2}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Account{Extra: tt.extra} + if got := a.GetRPMStickyBuffer(); got != tt.expected { + t.Errorf("GetRPMStickyBuffer() = %d, 
want %d", got, tt.expected) + } + }) + } +} diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index f192fba4..18a70c5c 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -25,6 +25,9 @@ type AccountRepository interface { // GetByCRSAccountID finds an account previously synced from CRS. // Returns (nil, nil) if not found. GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) + // FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora') + // 用于查找通过 linked_openai_account_id 关联的 Sora 账号 + FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) // ListCRSAccountIDs returns a map of crs_account_id -> local account ID // for all accounts that have been synced from CRS. ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) @@ -51,6 +54,8 @@ type AccountRepository interface { ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) + ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) + ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error @@ -116,6 +121,10 @@ type AccountService struct { groupRepo GroupRepository } +type groupExistenceBatchChecker interface { + ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error) +} + // NewAccountService 创建账号服务实例 func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) *AccountService { return &AccountService{ @@ -128,11 +137,8 @@ func 
NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*Account, error) { // 验证分组是否存在(如果指定了分组) if len(req.GroupIDs) > 0 { - for _, groupID := range req.GroupIDs { - _, err := s.groupRepo.GetByID(ctx, groupID) - if err != nil { - return nil, fmt.Errorf("get group: %w", err) - } + if err := s.validateGroupIDsExist(ctx, req.GroupIDs); err != nil { + return nil, err } } @@ -253,11 +259,8 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount // 先验证分组是否存在(在任何写操作之前) if req.GroupIDs != nil { - for _, groupID := range *req.GroupIDs { - _, err := s.groupRepo.GetByID(ctx, groupID) - if err != nil { - return nil, fmt.Errorf("get group: %w", err) - } + if err := s.validateGroupIDsExist(ctx, *req.GroupIDs); err != nil { + return nil, err } } @@ -297,6 +300,39 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error { return nil } +func (s *AccountService) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error { + if len(groupIDs) == 0 { + return nil + } + if s.groupRepo == nil { + return fmt.Errorf("group repository not configured") + } + + if batchChecker, ok := s.groupRepo.(groupExistenceBatchChecker); ok { + existsByID, err := batchChecker.ExistsByIDs(ctx, groupIDs) + if err != nil { + return fmt.Errorf("check groups exists: %w", err) + } + for _, groupID := range groupIDs { + if groupID <= 0 { + return fmt.Errorf("get group: %w", ErrGroupNotFound) + } + if !existsByID[groupID] { + return fmt.Errorf("get group: %w", ErrGroupNotFound) + } + } + return nil + } + + for _, groupID := range groupIDs { + _, err := s.groupRepo.GetByID(ctx, groupID) + if err != nil { + return fmt.Errorf("get group: %w", err) + } + } + return nil +} + // UpdateStatus 更新账号状态 func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error { account, err := s.accountRepo.GetByID(ctx, id) diff --git 
a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index a420d46b..768cf7b7 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -54,6 +54,10 @@ func (s *accountRepoStub) GetByCRSAccountID(ctx context.Context, crsAccountID st panic("unexpected GetByCRSAccountID call") } +func (s *accountRepoStub) FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) { + panic("unexpected FindByExtraField call") +} + func (s *accountRepoStub) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) { panic("unexpected ListCRSAccountIDs call") } @@ -143,6 +147,14 @@ func (s *accountRepoStub) ListSchedulableByGroupIDAndPlatforms(ctx context.Conte panic("unexpected ListSchedulableByGroupIDAndPlatforms call") } +func (s *accountRepoStub) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + panic("unexpected ListSchedulableUngroupedByPlatform call") +} + +func (s *accountRepoStub) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { + panic("unexpected ListSchedulableUngroupedByPlatforms call") +} + func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { panic("unexpected SetRateLimited call") } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 899a4498..c55e418d 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -12,13 +12,17 @@ import ( "io" "log" "net/http" + "net/url" "regexp" "strings" + "sync" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + 
"github.com/Wei-Shaw/sub2api/internal/util/soraerror" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/gin-gonic/gin" "github.com/google/uuid" @@ -31,6 +35,11 @@ var sseDataPrefix = regexp.MustCompile(`^data:\s*`) const ( testClaudeAPIURL = "https://api.anthropic.com/v1/messages" chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses" + soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接 + soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions" + soraInviteMineURL = "https://sora.chatgpt.com/backend/project_y/invite/mine" + soraBootstrapURL = "https://sora.chatgpt.com/backend/m/bootstrap" + soraRemainingURL = "https://sora.chatgpt.com/backend/nf/check" ) // TestEvent represents a SSE event for account testing @@ -38,6 +47,9 @@ type TestEvent struct { Type string `json:"type"` Text string `json:"text,omitempty"` Model string `json:"model,omitempty"` + Status string `json:"status,omitempty"` + Code string `json:"code,omitempty"` + Data any `json:"data,omitempty"` Success bool `json:"success,omitempty"` Error string `json:"error,omitempty"` } @@ -49,8 +61,13 @@ type AccountTestService struct { antigravityGatewayService *AntigravityGatewayService httpUpstream HTTPUpstream cfg *config.Config + soraTestGuardMu sync.Mutex + soraTestLastRun map[int64]time.Time + soraTestCooldown time.Duration } +const defaultSoraTestCooldown = 10 * time.Second + // NewAccountTestService creates a new AccountTestService func NewAccountTestService( accountRepo AccountRepository, @@ -65,6 +82,8 @@ func NewAccountTestService( antigravityGatewayService: antigravityGatewayService, httpUpstream: httpUpstream, cfg: cfg, + soraTestLastRun: make(map[int64]time.Time), + soraTestCooldown: defaultSoraTestCooldown, } } @@ -163,6 +182,10 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int return s.testAntigravityAccountConnection(c, account, modelID) } + if account.Platform == PlatformSora { + 
return s.testSoraAccountConnection(c, account) + } + return s.testClaudeAccountConnection(c, account, modelID) } @@ -462,6 +485,697 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account return s.processGeminiStream(c, resp.Body) } +type soraProbeStep struct { + Name string `json:"name"` + Status string `json:"status"` + HTTPStatus int `json:"http_status,omitempty"` + ErrorCode string `json:"error_code,omitempty"` + Message string `json:"message,omitempty"` +} + +type soraProbeSummary struct { + Status string `json:"status"` + Steps []soraProbeStep `json:"steps"` +} + +type soraProbeRecorder struct { + steps []soraProbeStep +} + +func (r *soraProbeRecorder) addStep(name, status string, httpStatus int, errorCode, message string) { + r.steps = append(r.steps, soraProbeStep{ + Name: name, + Status: status, + HTTPStatus: httpStatus, + ErrorCode: strings.TrimSpace(errorCode), + Message: strings.TrimSpace(message), + }) +} + +func (r *soraProbeRecorder) finalize() soraProbeSummary { + meSuccess := false + partial := false + for _, step := range r.steps { + if step.Name == "me" { + meSuccess = strings.EqualFold(step.Status, "success") + continue + } + if strings.EqualFold(step.Status, "failed") { + partial = true + } + } + + status := "success" + if !meSuccess { + status = "failed" + } else if partial { + status = "partial_success" + } + + return soraProbeSummary{ + Status: status, + Steps: append([]soraProbeStep(nil), r.steps...), + } +} + +func (s *AccountTestService) emitSoraProbeSummary(c *gin.Context, rec *soraProbeRecorder) { + if rec == nil { + return + } + summary := rec.finalize() + code := "" + for _, step := range summary.Steps { + if strings.EqualFold(step.Status, "failed") && strings.TrimSpace(step.ErrorCode) != "" { + code = step.ErrorCode + break + } + } + s.sendEvent(c, TestEvent{ + Type: "sora_test_result", + Status: summary.Status, + Code: code, + Data: summary, + }) +} + +func (s *AccountTestService) 
acquireSoraTestPermit(accountID int64) (time.Duration, bool) { + if accountID <= 0 { + return 0, true + } + s.soraTestGuardMu.Lock() + defer s.soraTestGuardMu.Unlock() + + if s.soraTestLastRun == nil { + s.soraTestLastRun = make(map[int64]time.Time) + } + cooldown := s.soraTestCooldown + if cooldown <= 0 { + cooldown = defaultSoraTestCooldown + } + + now := time.Now() + if lastRun, ok := s.soraTestLastRun[accountID]; ok { + elapsed := now.Sub(lastRun) + if elapsed < cooldown { + return cooldown - elapsed, false + } + } + s.soraTestLastRun[accountID] = now + return 0, true +} + +func ceilSeconds(d time.Duration) int { + if d <= 0 { + return 1 + } + sec := int(d / time.Second) + if d%time.Second != 0 { + sec++ + } + if sec < 1 { + sec = 1 + } + return sec +} + +// testSoraAPIKeyAccountConnection 测试 Sora apikey 类型账号的连通性。 +// 向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性和 API Key 有效性。 +func (s *AccountTestService) testSoraAPIKeyAccountConnection(c *gin.Context, account *Account) error { + ctx := c.Request.Context() + + apiKey := account.GetCredential("api_key") + if apiKey == "" { + return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 api_key 凭证") + } + + baseURL := account.GetBaseURL() + if baseURL == "" { + return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 base_url") + } + + // 验证 base_url 格式 + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("base_url 无效: %s", err.Error())) + } + upstreamURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/sora/v1/chat/completions" + + // 设置 SSE 头 + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.Flush() + + if wait, ok := s.acquireSoraTestPermit(account.ID); !ok { + msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait)) + return s.sendErrorAndEnd(c, msg) + } + + s.sendEvent(c, 
TestEvent{Type: "test_start", Model: "sora-upstream"}) + + // 构建轻量级 prompt-enhance 请求作为连通性测试 + testPayload := map[string]any{ + "model": "prompt-enhance-short-10s", + "messages": []map[string]string{{"role": "user", "content": "test"}}, + "stream": false, + } + payloadBytes, _ := json.Marshal(testPayload) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(payloadBytes)) + if err != nil { + return s.sendErrorAndEnd(c, "构建测试请求失败") + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + // 获取代理 URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("上游连接失败: %s", err.Error())) + } + defer func() { _ = resp.Body.Close() }() + + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024)) + + if resp.StatusCode == http.StatusOK { + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)}) + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效 (HTTP %d)", resp.StatusCode)}) + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } + + if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { + return s.sendErrorAndEnd(c, fmt.Sprintf("上游认证失败 (HTTP %d),请检查 API Key 是否正确", resp.StatusCode)) + } + + // 其他错误但能连通(如 400 参数错误)也算连通性测试通过 + if resp.StatusCode == http.StatusBadRequest { + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)}) + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效(上游返回 %d,参数校验错误属正常)", resp.StatusCode)}) + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } + + return s.sendErrorAndEnd(c, fmt.Sprintf("上游返回异常 HTTP %d: %s", resp.StatusCode, truncateSoraErrorBody(respBody, 
// testSoraAccountConnection tests a Sora account's connectivity over SSE.
// OAuth accounts: calls the /backend/me endpoint to validate the
// access_token, then runs best-effort auxiliary probes (subscription info
// and Sora2 capabilities) whose failures downgrade the result to
// partial_success without aborting.
// API-key accounts: delegated to testSoraAPIKeyAccountConnection, which
// sends a lightweight prompt-enhance request to the configured base_url.
func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error {
	// API-key accounts use a separate, upstream-oriented test flow.
	if account.Type == AccountTypeAPIKey {
		return s.testSoraAPIKeyAccountConnection(c, account)
	}

	ctx := c.Request.Context()
	recorder := &soraProbeRecorder{}

	authToken := account.GetCredential("access_token")
	if authToken == "" {
		// Without a token the mandatory "me" probe cannot run; report and abort.
		recorder.addStep("me", "failed", http.StatusUnauthorized, "missing_access_token", "No access token available")
		s.emitSoraProbeSummary(c, recorder)
		return s.sendErrorAndEnd(c, "No access token available")
	}

	// Set SSE headers before any event is written.
	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	c.Writer.Header().Set("X-Accel-Buffering", "no")
	c.Writer.Flush()

	// Per-account cooldown guard: refuse to hammer Sora's endpoints.
	if wait, ok := s.acquireSoraTestPermit(account.ID); !ok {
		msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait))
		recorder.addStep("rate_limit", "failed", http.StatusTooManyRequests, "test_rate_limited", msg)
		s.emitSoraProbeSummary(c, recorder)
		return s.sendErrorAndEnd(c, msg)
	}

	// Send test_start event
	s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora"})

	req, err := http.NewRequestWithContext(ctx, "GET", soraMeAPIURL, nil)
	if err != nil {
		recorder.addStep("me", "failed", 0, "request_build_failed", err.Error())
		s.emitSoraProbeSummary(c, recorder)
		return s.sendErrorAndEnd(c, "Failed to create request")
	}

	// Standard Sora client headers (mimic the mobile client).
	req.Header.Set("Authorization", "Bearer "+authToken)
	req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Accept-Language", "en-US,en;q=0.9")
	req.Header.Set("Origin", "https://sora.chatgpt.com")
	req.Header.Set("Referer", "https://sora.chatgpt.com/")

	// Get proxy URL
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}
	enableSoraTLSFingerprint := s.shouldEnableSoraTLSFingerprint()

	resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
	if err != nil {
		recorder.addStep("me", "failed", 0, "network_error", err.Error())
		s.emitSoraProbeSummary(c, recorder)
		return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
	}
	defer func() { _ = resp.Body.Close() }()

	// NOTE(review): unbounded read — the apikey test path caps the body at
	// 64 KiB via io.LimitReader; consider doing the same here. TODO confirm.
	body, _ := io.ReadAll(resp.Body)

	if resp.StatusCode != http.StatusOK {
		// A Cloudflare challenge means the egress IP is blocked, not that
		// the token is bad; log details for operators and abort.
		if isCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) {
			recorder.addStep("me", "failed", resp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
			s.emitSoraProbeSummary(c, recorder)
			s.logSoraCloudflareChallenge(account, proxyURL, soraMeAPIURL, resp.Header, body)
			return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage(fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", resp.StatusCode), resp.Header, body))
		}
		upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(body)
		switch {
		// Known terminal conditions get dedicated codes/messages; anything
		// else falls through to a generic error with the upstream payload.
		case resp.StatusCode == http.StatusUnauthorized && strings.EqualFold(upstreamCode, "token_invalidated"):
			recorder.addStep("me", "failed", resp.StatusCode, "token_invalidated", "Sora token invalidated")
			s.emitSoraProbeSummary(c, recorder)
			return s.sendErrorAndEnd(c, "Sora token 已失效(token_invalidated),请重新授权账号")
		case strings.EqualFold(upstreamCode, "unsupported_country_code"):
			recorder.addStep("me", "failed", resp.StatusCode, "unsupported_country_code", "Sora is unavailable in current egress region")
			s.emitSoraProbeSummary(c, recorder)
			return s.sendErrorAndEnd(c, "Sora 在当前网络出口地区不可用(unsupported_country_code),请切换到支持地区后重试")
		case strings.TrimSpace(upstreamMessage) != "":
			recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, upstreamMessage)
			s.emitSoraProbeSummary(c, recorder)
			return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, upstreamMessage))
		default:
			recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, "Sora me endpoint failed")
			s.emitSoraProbeSummary(c, recorder)
			return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512)))
		}
	}
	recorder.addStep("me", "success", resp.StatusCode, "", "me endpoint ok")

	// Parse the /me response to surface user identity in the test output.
	var meResp map[string]any
	if err := json.Unmarshal(body, &meResp); err != nil {
		// An HTTP 200 alone already proves the token is valid.
		s.sendEvent(c, TestEvent{Type: "content", Text: "Sora connection OK (token valid)"})
	} else {
		// Prefer the user's name, then email, for a friendlier message.
		info := "Sora connection OK"
		if name, ok := meResp["name"].(string); ok && name != "" {
			info = fmt.Sprintf("Sora connection OK - User: %s", name)
		} else if email, ok := meResp["email"].(string); ok && email != "" {
			info = fmt.Sprintf("Sora connection OK - Email: %s", email)
		}
		s.sendEvent(c, TestEvent{Type: "content", Text: info})
	}

	// Best-effort subscription lookup: failures only emit a warning and mark
	// the run partial; they never abort the connection test.
	subReq, err := http.NewRequestWithContext(ctx, "GET", soraBillingAPIURL, nil)
	if err == nil {
		subReq.Header.Set("Authorization", "Bearer "+authToken)
		subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
		subReq.Header.Set("Accept", "application/json")
		subReq.Header.Set("Accept-Language", "en-US,en;q=0.9")
		subReq.Header.Set("Origin", "https://sora.chatgpt.com")
		subReq.Header.Set("Referer", "https://sora.chatgpt.com/")

		subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint)
		if subErr != nil {
			recorder.addStep("subscription", "failed", 0, "network_error", subErr.Error())
			s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())})
		} else {
			subBody, _ := io.ReadAll(subResp.Body)
			_ = subResp.Body.Close()
			if subResp.StatusCode == http.StatusOK {
				recorder.addStep("subscription", "success", subResp.StatusCode, "", "subscription endpoint ok")
				if summary := parseSoraSubscriptionSummary(subBody); summary != "" {
					s.sendEvent(c, TestEvent{Type: "content", Text: summary})
				} else {
					s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"})
				}
			} else {
				if isCloudflareChallengeResponse(subResp.StatusCode, subResp.Header, subBody) {
					recorder.addStep("subscription", "failed", subResp.StatusCode, "cf_challenge", "Cloudflare challenge detected")
					s.logSoraCloudflareChallenge(account, proxyURL, soraBillingAPIURL, subResp.Header, subBody)
					s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Subscription check blocked by Cloudflare challenge (HTTP %d)", subResp.StatusCode), subResp.Header, subBody)})
				} else {
					upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(subBody)
					recorder.addStep("subscription", "failed", subResp.StatusCode, upstreamCode, upstreamMessage)
					s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)})
				}
			}
		}
	}

	// Sora2 capability probes (invite code + remaining quota), mirroring the
	// sora2api testing approach; also best-effort.
	s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, enableSoraTLSFingerprint, recorder)

	s.emitSoraProbeSummary(c, recorder)
	s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
	return nil
}
} else if recorder != nil { + code := "" + msg := "" + if bootstrapErr != nil { + code = "network_error" + msg = bootstrapErr.Error() + } + recorder.addStep("sora2_bootstrap", "failed", bootstrapStatus, code, msg) + } + } + + if inviteStatus != http.StatusOK { + if isCloudflareChallengeResponse(inviteStatus, inviteHeader, inviteBody) { + if recorder != nil { + recorder.addStep("sora2_invite", "failed", inviteStatus, "cf_challenge", "Cloudflare challenge detected") + } + s.logSoraCloudflareChallenge(account, proxyURL, soraInviteMineURL, inviteHeader, inviteBody) + s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 invite check blocked by Cloudflare challenge (HTTP %d)", inviteStatus), inviteHeader, inviteBody)}) + return + } + upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(inviteBody) + if recorder != nil { + recorder.addStep("sora2_invite", "failed", inviteStatus, upstreamCode, upstreamMessage) + } + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check returned %d", inviteStatus)}) + return + } + if recorder != nil { + recorder.addStep("sora2_invite", "success", inviteStatus, "", "invite endpoint ok") + } + + if summary := parseSoraInviteSummary(inviteBody); summary != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: summary}) + } else { + s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 invite check OK"}) + } + + remainingStatus, remainingHeader, remainingBody, remainingErr := s.fetchSoraTestEndpoint( + ctx, + account, + authToken, + soraRemainingURL, + proxyURL, + enableTLSFingerprint, + ) + if remainingErr != nil { + if recorder != nil { + recorder.addStep("sora2_remaining", "failed", 0, "network_error", remainingErr.Error()) + } + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check skipped: %s", remainingErr.Error())}) + return + } + if remainingStatus != http.StatusOK { + if 
isCloudflareChallengeResponse(remainingStatus, remainingHeader, remainingBody) { + if recorder != nil { + recorder.addStep("sora2_remaining", "failed", remainingStatus, "cf_challenge", "Cloudflare challenge detected") + } + s.logSoraCloudflareChallenge(account, proxyURL, soraRemainingURL, remainingHeader, remainingBody) + s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 remaining check blocked by Cloudflare challenge (HTTP %d)", remainingStatus), remainingHeader, remainingBody)}) + return + } + upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(remainingBody) + if recorder != nil { + recorder.addStep("sora2_remaining", "failed", remainingStatus, upstreamCode, upstreamMessage) + } + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check returned %d", remainingStatus)}) + return + } + if recorder != nil { + recorder.addStep("sora2_remaining", "success", remainingStatus, "", "remaining endpoint ok") + } + if summary := parseSoraRemainingSummary(remainingBody); summary != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: summary}) + } else { + s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 remaining check OK"}) + } +} + +func (s *AccountTestService) fetchSoraTestEndpoint( + ctx context.Context, + account *Account, + authToken string, + url string, + proxyURL string, + enableTLSFingerprint bool, +) (int, http.Header, []byte, error) { + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return 0, nil, nil, err + } + req.Header.Set("Authorization", "Bearer "+authToken) + req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + req.Header.Set("Accept", "application/json") + req.Header.Set("Accept-Language", "en-US,en;q=0.9") + req.Header.Set("Origin", "https://sora.chatgpt.com") + req.Header.Set("Referer", "https://sora.chatgpt.com/") + + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, 
// parseSoraSubscriptionSummary renders a one-line summary of the first entry
// in a /billing/subscriptions response ("Subscription: <title> | <id> |
// end=<ts>"), or "" when the body is unparsable or carries nothing useful.
func parseSoraSubscriptionSummary(body []byte) string {
	var payload struct {
		Data []struct {
			Plan struct {
				ID    string `json:"id"`
				Title string `json:"title"`
			} `json:"plan"`
			EndTS string `json:"end_ts"`
		} `json:"data"`
	}
	if json.Unmarshal(body, &payload) != nil || len(payload.Data) == 0 {
		return ""
	}

	entry := payload.Data[0]
	var fields []string
	if entry.Plan.Title != "" {
		fields = append(fields, entry.Plan.Title)
	}
	if entry.Plan.ID != "" {
		fields = append(fields, entry.Plan.ID)
	}
	if entry.EndTS != "" {
		fields = append(fields, "end="+entry.EndTS)
	}
	if len(fields) == 0 {
		return ""
	}
	return "Subscription: " + strings.Join(fields, " | ")
}

// parseSoraInviteSummary renders a one-line summary of an invite/mine
// response ("Sora2: supported | invite=<code> | used=<n>/<m>"), or "" when
// the body is not valid JSON.
func parseSoraInviteSummary(body []byte) string {
	var payload struct {
		InviteCode    string `json:"invite_code"`
		RedeemedCount int64  `json:"redeemed_count"`
		TotalCount    int64  `json:"total_count"`
	}
	if json.Unmarshal(body, &payload) != nil {
		return ""
	}

	fields := []string{"Sora2: supported"}
	if payload.InviteCode != "" {
		fields = append(fields, "invite="+payload.InviteCode)
	}
	if payload.TotalCount > 0 {
		fields = append(fields, fmt.Sprintf("used=%d/%d", payload.RedeemedCount, payload.TotalCount))
	}
	return strings.Join(fields, " | ")
}
`json:"rate_limit_and_credit_balance"` + } + if err := json.Unmarshal(body, &remainingResp); err != nil { + return "" + } + info := remainingResp.RateLimitAndCreditBalance + parts := []string{fmt.Sprintf("Sora2 remaining: %d", info.EstimatedNumVideosRemaining)} + if info.RateLimitReached { + parts = append(parts, "rate_limited=true") + } + if info.AccessResetsInSeconds > 0 { + parts = append(parts, fmt.Sprintf("reset_in=%ds", info.AccessResetsInSeconds)) + } + return strings.Join(parts, " | ") +} + +func (s *AccountTestService) shouldEnableSoraTLSFingerprint() bool { + if s == nil || s.cfg == nil { + return true + } + return !s.cfg.Sora.Client.DisableTLSFingerprint +} + +func isCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool { + return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body) +} + +func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string { + return soraerror.FormatCloudflareChallengeMessage(base, headers, body) +} + +func extractCloudflareRayID(headers http.Header, body []byte) string { + return soraerror.ExtractCloudflareRayID(headers, body) +} + +func extractSoraEgressIPHint(headers http.Header) string { + if headers == nil { + return "unknown" + } + candidates := []string{ + "x-openai-public-ip", + "x-envoy-external-address", + "cf-connecting-ip", + "x-forwarded-for", + } + for _, key := range candidates { + if value := strings.TrimSpace(headers.Get(key)); value != "" { + return value + } + } + return "unknown" +} + +func sanitizeProxyURLForLog(raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + u, err := url.Parse(raw) + if err != nil { + return "" + } + if u.User != nil { + u.User = nil + } + return u.String() +} + +func endpointPathForLog(endpoint string) string { + parsed, err := url.Parse(strings.TrimSpace(endpoint)) + if err != nil || parsed.Path == "" { + return endpoint + } + return parsed.Path +} + +func (s 
*AccountTestService) logSoraCloudflareChallenge(account *Account, proxyURL, endpoint string, headers http.Header, body []byte) { + accountID := int64(0) + platform := "" + proxyID := "none" + if account != nil { + accountID = account.ID + platform = account.Platform + if account.ProxyID != nil { + proxyID = fmt.Sprintf("%d", *account.ProxyID) + } + } + cfRay := extractCloudflareRayID(headers, body) + if cfRay == "" { + cfRay = "unknown" + } + log.Printf( + "[SoraCFChallenge] account_id=%d platform=%s endpoint=%s path=%s proxy_id=%s proxy_url=%s cf_ray=%s egress_ip_hint=%s", + accountID, + platform, + endpoint, + endpointPathForLog(endpoint), + proxyID, + sanitizeProxyURLForLog(proxyURL), + cfRay, + extractSoraEgressIPHint(headers), + ) +} + +func truncateSoraErrorBody(body []byte, max int) string { + return soraerror.TruncateBody(body, max) +} + // testAntigravityAccountConnection tests an Antigravity account's connection // 支持 Claude 和 Gemini 两种协议,使用非流式请求 func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error { diff --git a/backend/internal/service/account_test_service_sora_test.go b/backend/internal/service/account_test_service_sora_test.go new file mode 100644 index 00000000..3dfac786 --- /dev/null +++ b/backend/internal/service/account_test_service_sora_test.go @@ -0,0 +1,319 @@ +package service + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type queuedHTTPUpstream struct { + responses []*http.Response + requests []*http.Request + tlsFlags []bool +} + +func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + return nil, fmt.Errorf("unexpected Do call") +} + +func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, enableTLSFingerprint bool) (*http.Response, 
error) { + u.requests = append(u.requests, req) + u.tlsFlags = append(u.tlsFlags, enableTLSFingerprint) + if len(u.responses) == 0 { + return nil, fmt.Errorf("no mocked response") + } + resp := u.responses[0] + u.responses = u.responses[1:] + return resp, nil +} + +func newJSONResponse(status int, body string) *http.Response { + return &http.Response{ + StatusCode: status, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func newJSONResponseWithHeader(status int, body, key, value string) *http.Response { + resp := newJSONResponse(status, body) + resp.Header.Set(key, value) + return resp +} + +func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil) + return c, rec +} + +func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`), + newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`), + newJSONResponse(http.StatusOK, `{"invite_code":"inv_abc","redeemed_count":3,"total_count":50}`), + newJSONResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":27,"rate_limit_reached":false,"access_resets_in_seconds":46833}}`), + }, + } + svc := &AccountTestService{ + httpUpstream: upstream, + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + TLSFingerprint: config.TLSFingerprintConfig{ + Enabled: true, + }, + }, + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + DisableTLSFingerprint: false, + }, + }, + }, + } + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": 
"test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.NoError(t, err) + require.Len(t, upstream.requests, 4) + require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String()) + require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String()) + require.Equal(t, soraInviteMineURL, upstream.requests[2].URL.String()) + require.Equal(t, soraRemainingURL, upstream.requests[3].URL.String()) + require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization")) + require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization")) + require.Equal(t, []bool{true, true, true, true}, upstream.tlsFlags) + + body := rec.Body.String() + require.Contains(t, body, `"type":"test_start"`) + require.Contains(t, body, "Sora connection OK - Email: demo@example.com") + require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z") + require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50") + require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s") + require.Contains(t, body, `"type":"sora_test_result"`) + require.Contains(t, body, `"status":"success"`) + require.Contains(t, body, `"type":"test_complete","success":true`) +} + +func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusOK, `{"name":"demo-user"}`), + newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`), + newJSONResponse(http.StatusUnauthorized, `{"error":{"message":"Unauthorized"}}`), + newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + 
+ c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.NoError(t, err) + require.Len(t, upstream.requests, 4) + body := rec.Body.String() + require.Contains(t, body, "Sora connection OK - User: demo-user") + require.Contains(t, body, "Subscription check returned 403") + require.Contains(t, body, "Sora2 invite check returned 401") + require.Contains(t, body, `"type":"sora_test_result"`) + require.Contains(t, body, `"status":"partial_success"`) + require.Contains(t, body, `"type":"test_complete","success":true`) +} + +func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponseWithHeader(http.StatusForbidden, `Just a moment...`, "cf-ray", "9cff2d62d83bb98d"), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.Error(t, err) + require.Contains(t, err.Error(), "Cloudflare challenge") + require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d") + body := rec.Body.String() + require.Contains(t, body, `"type":"error"`) + require.Contains(t, body, "Cloudflare challenge") + require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d") +} + +func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponseWithHeader(http.StatusTooManyRequests, `Just a moment...`, "cf-mitigated", "challenge"), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec 
:= newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.Error(t, err) + require.Contains(t, err.Error(), "Cloudflare challenge") + require.Contains(t, err.Error(), "HTTP 429") + body := rec.Body.String() + require.Contains(t, body, "Cloudflare challenge") +} + +func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.Error(t, err) + require.Contains(t, err.Error(), "token_invalidated") + body := rec.Body.String() + require.Contains(t, body, `"type":"sora_test_result"`) + require.Contains(t, body, `"status":"failed"`) + require.Contains(t, body, "token_invalidated") + require.NotContains(t, body, `"type":"test_complete","success":true`) +} + +func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`), + }, + } + svc := &AccountTestService{ + httpUpstream: upstream, + soraTestCooldown: time.Hour, + } + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c1, _ := newSoraTestContext() + err := svc.testSoraAccountConnection(c1, account) + require.NoError(t, err) + + c2, rec2 := newSoraTestContext() + err = svc.testSoraAccountConnection(c2, account) + require.Error(t, err) + require.Contains(t, err.Error(), "测试过于频繁") + body := 
rec2.Body.String() + require.Contains(t, body, `"type":"sora_test_result"`) + require.Contains(t, body, `"code":"test_rate_limited"`) + require.Contains(t, body, `"status":"failed"`) + require.NotContains(t, body, `"type":"test_complete","success":true`) +} + +func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusOK, `{"name":"demo-user"}`), + newJSONResponse(http.StatusForbidden, `Just a moment...`), + newJSONResponse(http.StatusForbidden, `Just a moment...`), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.NoError(t, err) + body := rec.Body.String() + require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)") + require.Contains(t, body, "Sora2 invite check blocked by Cloudflare challenge (HTTP 403)") + require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d") + require.Contains(t, body, `"type":"test_complete","success":true`) +} + +func TestSanitizeProxyURLForLog(t *testing.T) { + require.Equal(t, "http://proxy.example.com:8080", sanitizeProxyURLForLog("http://user:pass@proxy.example.com:8080")) + require.Equal(t, "", sanitizeProxyURLForLog("")) + require.Equal(t, "", sanitizeProxyURLForLog("://invalid")) +} + +func TestExtractSoraEgressIPHint(t *testing.T) { + h := make(http.Header) + h.Set("x-openai-public-ip", "203.0.113.10") + require.Equal(t, "203.0.113.10", extractSoraEgressIPHint(h)) + + h2 := make(http.Header) + h2.Set("x-envoy-external-address", "198.51.100.9") + require.Equal(t, "198.51.100.9", extractSoraEgressIPHint(h2)) + + require.Equal(t, "unknown", extractSoraEgressIPHint(nil)) + 
require.Equal(t, "unknown", extractSoraEgressIPHint(http.Header{})) +} diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 304c5781..6dee6c13 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -4,11 +4,14 @@ import ( "context" "fmt" "log" + "strings" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "golang.org/x/sync/errgroup" ) type UsageLogRepository interface { @@ -32,12 +35,13 @@ type UsageLogRepository interface { // Admin dashboard stats GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) - GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) - GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) + GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) + GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) + GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) 
([]usagestats.APIKeyUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) - GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) - GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) + GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) + GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) // User dashboard stats GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) @@ -61,6 +65,10 @@ type UsageLogRepository interface { GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) } +type accountWindowStatsBatchReader interface { + GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) +} + // apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at) type apiUsageCache struct { response *ClaudeUsageResponse @@ -217,12 +225,20 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U } if account.Platform == PlatformGemini { - return s.getGeminiUsage(ctx, account) + usage, err := s.getGeminiUsage(ctx, account) + if err == nil { + s.tryClearRecoverableAccountError(ctx, account) + } + return usage, err } // Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度 if account.Platform == PlatformAntigravity { - return s.getAntigravityUsage(ctx, account) + usage, err := s.getAntigravityUsage(ctx, account) + if err == nil { + s.tryClearRecoverableAccountError(ctx, account) + } + return usage, err } // 只有oauth类型账号可以通过API获取usage(有profile scope) @@ -256,6 +272,7 @@ 
func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U // 4. 添加窗口统计(有独立缓存,1 分钟) s.addWindowStats(ctx, account, usage) + s.tryClearRecoverableAccountError(ctx, account) return usage, nil } @@ -287,7 +304,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou } dayStart := geminiDailyWindowStart(now) - stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil) + stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil, nil) if err != nil { return nil, fmt.Errorf("get gemini usage stats failed: %w", err) } @@ -309,7 +326,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou // Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m) minuteStart := now.Truncate(time.Minute) minuteResetAt := minuteStart.Add(time.Minute) - minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil) + minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil, nil) if err != nil { return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err) } @@ -430,6 +447,78 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64 }, nil } +// GetTodayStatsBatch 批量获取账号今日统计,优先走批量 SQL,失败时回退单账号查询。 +func (s *AccountUsageService) GetTodayStatsBatch(ctx context.Context, accountIDs []int64) (map[int64]*WindowStats, error) { + uniqueIDs := make([]int64, 0, len(accountIDs)) + seen := make(map[int64]struct{}, len(accountIDs)) + for _, accountID := range accountIDs { + if accountID <= 0 { + continue + } + if _, exists := seen[accountID]; exists { + continue + } + seen[accountID] = struct{}{} + uniqueIDs = append(uniqueIDs, accountID) + } + + result := make(map[int64]*WindowStats, len(uniqueIDs)) + if len(uniqueIDs) == 0 { + return result, nil + } + + 
startTime := timezone.Today() + if batchReader, ok := s.usageLogRepo.(accountWindowStatsBatchReader); ok { + statsByAccount, err := batchReader.GetAccountWindowStatsBatch(ctx, uniqueIDs, startTime) + if err == nil { + for _, accountID := range uniqueIDs { + result[accountID] = windowStatsFromAccountStats(statsByAccount[accountID]) + } + return result, nil + } + } + + var mu sync.Mutex + g, gctx := errgroup.WithContext(ctx) + g.SetLimit(8) + + for _, accountID := range uniqueIDs { + id := accountID + g.Go(func() error { + stats, err := s.usageLogRepo.GetAccountWindowStats(gctx, id, startTime) + if err != nil { + return nil + } + mu.Lock() + result[id] = windowStatsFromAccountStats(stats) + mu.Unlock() + return nil + }) + } + + _ = g.Wait() + + for _, accountID := range uniqueIDs { + if _, ok := result[accountID]; !ok { + result[accountID] = &WindowStats{} + } + } + return result, nil +} + +func windowStatsFromAccountStats(stats *usagestats.AccountStats) *WindowStats { + if stats == nil { + return &WindowStats{} + } + return &WindowStats{ + Requests: stats.Requests, + Tokens: stats.Tokens, + Cost: stats.Cost, + StandardCost: stats.StandardCost, + UserCost: stats.UserCost, + } +} + func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) { stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime) if err != nil { @@ -486,6 +575,32 @@ func parseTime(s string) (time.Time, error) { return time.Time{}, fmt.Errorf("unable to parse time: %s", s) } +func (s *AccountUsageService) tryClearRecoverableAccountError(ctx context.Context, account *Account) { + if account == nil || account.Status != StatusError { + return + } + + msg := strings.ToLower(strings.TrimSpace(account.ErrorMessage)) + if msg == "" { + return + } + + if !strings.Contains(msg, "token refresh failed") && + !strings.Contains(msg, "invalid_client") && + !strings.Contains(msg, 
"missing_project_id") && + !strings.Contains(msg, "unauthenticated") { + return + } + + if err := s.accountRepo.ClearError(ctx, account.ID); err != nil { + log.Printf("[usage] failed to clear recoverable account error for account %d: %v", account.ID, err) + return + } + + account.Status = StatusActive + account.ErrorMessage = "" +} + // buildUsageInfo 构建UsageInfo func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedAt *time.Time) *UsageInfo { info := &UsageInfo{ diff --git a/backend/internal/service/account_wildcard_test.go b/backend/internal/service/account_wildcard_test.go index 90e5b573..7782f948 100644 --- a/backend/internal/service/account_wildcard_test.go +++ b/backend/internal/service/account_wildcard_test.go @@ -267,3 +267,119 @@ func TestAccountGetMappedModel(t *testing.T) { }) } } + +func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) { + account := &Account{ + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-3-pro-high": "gemini-3.1-pro-high", + }, + }, + } + + mapping := account.GetModelMapping() + if mapping["gemini-3-flash"] != "gemini-3-flash" { + t.Fatalf("expected gemini-3-flash passthrough to be auto-filled, got: %q", mapping["gemini-3-flash"]) + } + if mapping["gemini-3.1-pro-high"] != "gemini-3.1-pro-high" { + t.Fatalf("expected gemini-3.1-pro-high passthrough to be auto-filled, got: %q", mapping["gemini-3.1-pro-high"]) + } + if mapping["gemini-3.1-pro-low"] != "gemini-3.1-pro-low" { + t.Fatalf("expected gemini-3.1-pro-low passthrough to be auto-filled, got: %q", mapping["gemini-3.1-pro-low"]) + } +} + +func TestAccountGetModelMapping_AntigravityRespectsWildcardOverride(t *testing.T) { + account := &Account{ + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-3*": "gemini-3.1-pro-high", + }, + }, + } + + mapping := account.GetModelMapping() + if _, exists := 
mapping["gemini-3-flash"]; exists { + t.Fatalf("did not expect explicit gemini-3-flash passthrough when wildcard already exists") + } + if _, exists := mapping["gemini-3.1-pro-high"]; exists { + t.Fatalf("did not expect explicit gemini-3.1-pro-high passthrough when wildcard already exists") + } + if _, exists := mapping["gemini-3.1-pro-low"]; exists { + t.Fatalf("did not expect explicit gemini-3.1-pro-low passthrough when wildcard already exists") + } + if mapped := account.GetMappedModel("gemini-3-flash"); mapped != "gemini-3.1-pro-high" { + t.Fatalf("expected wildcard mapping to stay effective, got: %q", mapped) + } +} + +func TestAccountGetModelMapping_CacheInvalidatesOnCredentialsReplace(t *testing.T) { + account := &Account{ + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-3-5-sonnet": "upstream-a", + }, + }, + } + + first := account.GetModelMapping() + if first["claude-3-5-sonnet"] != "upstream-a" { + t.Fatalf("unexpected first mapping: %v", first) + } + + account.Credentials = map[string]any{ + "model_mapping": map[string]any{ + "claude-3-5-sonnet": "upstream-b", + }, + } + second := account.GetModelMapping() + if second["claude-3-5-sonnet"] != "upstream-b" { + t.Fatalf("expected cache invalidated after credentials replace, got: %v", second) + } +} + +func TestAccountGetModelMapping_CacheInvalidatesOnMappingLenChange(t *testing.T) { + rawMapping := map[string]any{ + "claude-sonnet": "sonnet-a", + } + account := &Account{ + Credentials: map[string]any{ + "model_mapping": rawMapping, + }, + } + + first := account.GetModelMapping() + if len(first) != 1 { + t.Fatalf("unexpected first mapping length: %d", len(first)) + } + + rawMapping["claude-opus"] = "opus-b" + second := account.GetModelMapping() + if second["claude-opus"] != "opus-b" { + t.Fatalf("expected cache invalidated after mapping len change, got: %v", second) + } +} + +func TestAccountGetModelMapping_CacheInvalidatesOnInPlaceValueChange(t *testing.T) { + rawMapping := 
map[string]any{ + "claude-sonnet": "sonnet-a", + } + account := &Account{ + Credentials: map[string]any{ + "model_mapping": rawMapping, + }, + } + + first := account.GetModelMapping() + if first["claude-sonnet"] != "sonnet-a" { + t.Fatalf("unexpected first mapping: %v", first) + } + + rawMapping["claude-sonnet"] = "sonnet-b" + second := account.GetModelMapping() + if second["claude-sonnet"] != "sonnet-b" { + t.Fatalf("expected cache invalidated after in-place value change, got: %v", second) + } +} diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 1f6e91e5..67e7c783 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -4,11 +4,17 @@ import ( "context" "errors" "fmt" - "log" + "io" + "net/http" "strings" "time" + dbent "github.com/Wei-Shaw/sub2api/ent" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/util/soraerror" ) // AdminService interface defines admin management operations @@ -38,6 +44,9 @@ type AdminService interface { GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error + // API Key management (admin) + AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error) + // Account management ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) GetAccount(ctx context.Context, id int64) (*Account, error) @@ -50,6 +59,7 @@ type AdminService interface { SetAccountError(ctx context.Context, id int64, errorMsg string) error SetAccountSchedulable(ctx context.Context, id int64, schedulable 
bool) (*Account, error) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) + CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error // Proxy management ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) @@ -65,6 +75,7 @@ type AdminService interface { GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) + CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error) // Redeem code management ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) @@ -77,13 +88,14 @@ type AdminService interface { // CreateUserInput represents input for creating a new user via admin operations. 
type CreateUserInput struct { - Email string - Password string - Username string - Notes string - Balance float64 - Concurrency int - AllowedGroups []int64 + Email string + Password string + Username string + Notes string + Balance float64 + Concurrency int + AllowedGroups []int64 + SoraStorageQuotaBytes int64 } type UpdateUserInput struct { @@ -97,7 +109,8 @@ type UpdateUserInput struct { AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组" // GroupRates 用户专属分组倍率配置 // map[groupID]*rate,nil 表示删除该分组的专属倍率 - GroupRates map[int64]*float64 + GroupRates map[int64]*float64 + SoraStorageQuotaBytes *int64 } type CreateGroupInput struct { @@ -111,11 +124,16 @@ type CreateGroupInput struct { WeeklyLimitUSD *float64 // 周限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD) // 图片生成计费配置(仅 antigravity 平台使用) - ImagePrice1K *float64 - ImagePrice2K *float64 - ImagePrice4K *float64 - ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 - FallbackGroupID *int64 // 降级分组 ID + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 + // Sora 按次计费配置 + SoraImagePrice360 *float64 + SoraImagePrice540 *float64 + SoraVideoPricePerRequest *float64 + SoraVideoPricePerRequestHD *float64 + ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID // 无效请求兜底分组 ID(仅 anthropic 平台使用) FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) @@ -124,6 +142,8 @@ type CreateGroupInput struct { MCPXMLInject *bool // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string + // Sora 存储配额 + SoraStorageQuotaBytes int64 // 从指定分组复制账号(创建分组后在同一事务内绑定) CopyAccountsFromGroupIDs []int64 } @@ -140,11 +160,16 @@ type UpdateGroupInput struct { WeeklyLimitUSD *float64 // 周限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD) // 图片生成计费配置(仅 antigravity 平台使用) - ImagePrice1K *float64 - ImagePrice2K *float64 - ImagePrice4K *float64 - ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 - FallbackGroupID *int64 // 降级分组 ID + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 + // Sora 按次计费配置 + 
SoraImagePrice360 *float64 + SoraImagePrice540 *float64 + SoraVideoPricePerRequest *float64 + SoraVideoPricePerRequestHD *float64 + ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID // 无效请求兜底分组 ID(仅 anthropic 平台使用) FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) @@ -153,6 +178,8 @@ type UpdateGroupInput struct { MCPXMLInject *bool // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes *[]string + // Sora 存储配额 + SoraStorageQuotaBytes *int64 // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 } @@ -220,6 +247,14 @@ type BulkUpdateAccountResult struct { Error string `json:"error,omitempty"` } +// AdminUpdateAPIKeyGroupIDResult is the result of AdminUpdateAPIKeyGroupID. +type AdminUpdateAPIKeyGroupIDResult struct { + APIKey *APIKey + AutoGrantedGroupAccess bool // true if a new exclusive group permission was auto-added + GrantedGroupID *int64 // the group ID that was auto-granted + GrantedGroupName string // the group name that was auto-granted +} + // BulkUpdateAccountsResult is the aggregated response for bulk updates. 
type BulkUpdateAccountsResult struct { Success int `json:"success"` @@ -278,6 +313,32 @@ type ProxyTestResult struct { CountryCode string `json:"country_code,omitempty"` } +type ProxyQualityCheckResult struct { + ProxyID int64 `json:"proxy_id"` + Score int `json:"score"` + Grade string `json:"grade"` + Summary string `json:"summary"` + ExitIP string `json:"exit_ip,omitempty"` + Country string `json:"country,omitempty"` + CountryCode string `json:"country_code,omitempty"` + BaseLatencyMs int64 `json:"base_latency_ms,omitempty"` + PassedCount int `json:"passed_count"` + WarnCount int `json:"warn_count"` + FailedCount int `json:"failed_count"` + ChallengeCount int `json:"challenge_count"` + CheckedAt int64 `json:"checked_at"` + Items []ProxyQualityCheckItem `json:"items"` +} + +type ProxyQualityCheckItem struct { + Target string `json:"target"` + Status string `json:"status"` // pass/warn/fail/challenge + HTTPStatus int `json:"http_status,omitempty"` + LatencyMs int64 `json:"latency_ms,omitempty"` + Message string `json:"message,omitempty"` + CFRay string `json:"cf_ray,omitempty"` +} + // ProxyExitInfo represents proxy exit information from ip-api.com type ProxyExitInfo struct { IP string @@ -292,11 +353,64 @@ type ProxyExitInfoProber interface { ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error) } +type proxyQualityTarget struct { + Target string + URL string + Method string + AllowedStatuses map[int]struct{} +} + +var proxyQualityTargets = []proxyQualityTarget{ + { + Target: "openai", + URL: "https://api.openai.com/v1/models", + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusUnauthorized: {}, + }, + }, + { + Target: "anthropic", + URL: "https://api.anthropic.com/v1/messages", + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusUnauthorized: {}, + http.StatusMethodNotAllowed: {}, + http.StatusNotFound: {}, + http.StatusBadRequest: {}, + }, + }, + { + Target: "gemini", + URL: 
"https://generativelanguage.googleapis.com/$discovery/rest?version=v1beta", + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusOK: {}, + }, + }, + { + Target: "sora", + URL: "https://sora.chatgpt.com/backend/me", + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusUnauthorized: {}, + }, + }, +} + +const ( + proxyQualityRequestTimeout = 15 * time.Second + proxyQualityResponseHeaderTimeout = 10 * time.Second + proxyQualityMaxBodyBytes = int64(8 * 1024) + proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36" +) + // adminServiceImpl implements AdminService type adminServiceImpl struct { userRepo UserRepository groupRepo GroupRepository accountRepo AccountRepository + soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储 proxyRepo ProxyRepository apiKeyRepo APIKeyRepository redeemCodeRepo RedeemCodeRepository @@ -305,6 +419,17 @@ type adminServiceImpl struct { proxyProber ProxyExitInfoProber proxyLatencyCache ProxyLatencyCache authCacheInvalidator APIKeyAuthCacheInvalidator + entClient *dbent.Client // 用于开启数据库事务 + settingService *SettingService + defaultSubAssigner DefaultSubscriptionAssigner +} + +type userGroupRateBatchReader interface { + GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) +} + +type groupExistenceBatchReader interface { + ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error) } // NewAdminService creates a new AdminService @@ -312,6 +437,7 @@ func NewAdminService( userRepo UserRepository, groupRepo GroupRepository, accountRepo AccountRepository, + soraAccountRepo SoraAccountRepository, proxyRepo ProxyRepository, apiKeyRepo APIKeyRepository, redeemCodeRepo RedeemCodeRepository, @@ -320,11 +446,15 @@ func NewAdminService( proxyProber ProxyExitInfoProber, proxyLatencyCache ProxyLatencyCache, authCacheInvalidator APIKeyAuthCacheInvalidator, + entClient 
*dbent.Client, + settingService *SettingService, + defaultSubAssigner DefaultSubscriptionAssigner, ) AdminService { return &adminServiceImpl{ userRepo: userRepo, groupRepo: groupRepo, accountRepo: accountRepo, + soraAccountRepo: soraAccountRepo, proxyRepo: proxyRepo, apiKeyRepo: apiKeyRepo, redeemCodeRepo: redeemCodeRepo, @@ -333,6 +463,9 @@ func NewAdminService( proxyProber: proxyProber, proxyLatencyCache: proxyLatencyCache, authCacheInvalidator: authCacheInvalidator, + entClient: entClient, + settingService: settingService, + defaultSubAssigner: defaultSubAssigner, } } @@ -345,18 +478,43 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi } // 批量加载用户专属分组倍率 if s.userGroupRateRepo != nil && len(users) > 0 { - for i := range users { - rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID) - if err != nil { - log.Printf("failed to load user group rates: user_id=%d err=%v", users[i].ID, err) - continue + if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok { + userIDs := make([]int64, 0, len(users)) + for i := range users { + userIDs = append(userIDs, users[i].ID) } - users[i].GroupRates = rates + ratesByUser, err := batchRepo.GetByUserIDs(ctx, userIDs) + if err != nil { + logger.LegacyPrintf("service.admin", "failed to load user group rates in batch: err=%v", err) + s.loadUserGroupRatesOneByOne(ctx, users) + } else { + for i := range users { + if rates, ok := ratesByUser[users[i].ID]; ok { + users[i].GroupRates = rates + } + } + } + } else { + s.loadUserGroupRatesOneByOne(ctx, users) } } return users, result.Total, nil } +func (s *adminServiceImpl) loadUserGroupRatesOneByOne(ctx context.Context, users []User) { + if s.userGroupRateRepo == nil { + return + } + for i := range users { + rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID) + if err != nil { + logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err) + continue + } + users[i].GroupRates 
= rates + } +} + func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) { user, err := s.userRepo.GetByID(ctx, id) if err != nil { @@ -366,7 +524,7 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) if s.userGroupRateRepo != nil { rates, err := s.userGroupRateRepo.GetByUserID(ctx, id) if err != nil { - log.Printf("failed to load user group rates: user_id=%d err=%v", id, err) + logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", id, err) } else { user.GroupRates = rates } @@ -376,14 +534,15 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) { user := &User{ - Email: input.Email, - Username: input.Username, - Notes: input.Notes, - Role: RoleUser, // Always create as regular user, never admin - Balance: input.Balance, - Concurrency: input.Concurrency, - Status: StatusActive, - AllowedGroups: input.AllowedGroups, + Email: input.Email, + Username: input.Username, + Notes: input.Notes, + Role: RoleUser, // Always create as regular user, never admin + Balance: input.Balance, + Concurrency: input.Concurrency, + Status: StatusActive, + AllowedGroups: input.AllowedGroups, + SoraStorageQuotaBytes: input.SoraStorageQuotaBytes, } if err := user.SetPassword(input.Password); err != nil { return nil, err @@ -391,9 +550,27 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu if err := s.userRepo.Create(ctx, user); err != nil { return nil, err } + s.assignDefaultSubscriptions(ctx, user.ID) return user, nil } +func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userID int64) { + if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { + return + } + items := s.settingService.GetDefaultSubscriptions(ctx) + for _, item := range items { + if _, _, err := 
s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{ + UserID: userID, + GroupID: item.GroupID, + ValidityDays: item.ValidityDays, + Notes: "auto assigned by default user subscriptions setting", + }); err != nil { + logger.LegacyPrintf("service.admin", "failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err) + } + } +} + func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) { user, err := s.userRepo.GetByID(ctx, id) if err != nil { @@ -437,6 +614,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda user.AllowedGroups = *input.AllowedGroups } + if input.SoraStorageQuotaBytes != nil { + user.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes + } + if err := s.userRepo.Update(ctx, user); err != nil { return nil, err } @@ -444,7 +625,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda // 同步用户专属分组倍率 if input.GroupRates != nil && s.userGroupRateRepo != nil { if err := s.userGroupRateRepo.SyncUserGroupRates(ctx, user.ID, input.GroupRates); err != nil { - log.Printf("failed to sync user group rates: user_id=%d err=%v", user.ID, err) + logger.LegacyPrintf("service.admin", "failed to sync user group rates: user_id=%d err=%v", user.ID, err) } } @@ -458,7 +639,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda if concurrencyDiff != 0 { code, err := GenerateRedeemCode() if err != nil { - log.Printf("failed to generate adjustment redeem code: %v", err) + logger.LegacyPrintf("service.admin", "failed to generate adjustment redeem code: %v", err) return user, nil } adjustmentRecord := &RedeemCode{ @@ -471,7 +652,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda now := time.Now() adjustmentRecord.UsedAt = &now if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil { - log.Printf("failed to create concurrency 
adjustment redeem code: %v", err) + logger.LegacyPrintf("service.admin", "failed to create concurrency adjustment redeem code: %v", err) } } @@ -488,7 +669,7 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error { return errors.New("cannot delete admin user") } if err := s.userRepo.Delete(ctx, id); err != nil { - log.Printf("delete user failed: user_id=%d err=%v", id, err) + logger.LegacyPrintf("service.admin", "delete user failed: user_id=%d err=%v", id, err) return err } if s.authCacheInvalidator != nil { @@ -531,7 +712,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, userID); err != nil { - log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err) + logger.LegacyPrintf("service.admin", "invalidate user balance cache failed: user_id=%d err=%v", userID, err) } }() } @@ -539,7 +720,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, if balanceDiff != 0 { code, err := GenerateRedeemCode() if err != nil { - log.Printf("failed to generate adjustment redeem code: %v", err) + logger.LegacyPrintf("service.admin", "failed to generate adjustment redeem code: %v", err) return user, nil } @@ -555,7 +736,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, adjustmentRecord.UsedAt = &now if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil { - log.Printf("failed to create balance adjustment redeem code: %v", err) + logger.LegacyPrintf("service.admin", "failed to create balance adjustment redeem code: %v", err) } } @@ -564,7 +745,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) { params := 
pagination.PaginationParams{Page: page, PageSize: pageSize} - keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) + keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{}) if err != nil { return nil, 0, err } @@ -639,6 +820,10 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn imagePrice1K := normalizePrice(input.ImagePrice1K) imagePrice2K := normalizePrice(input.ImagePrice2K) imagePrice4K := normalizePrice(input.ImagePrice4K) + soraImagePrice360 := normalizePrice(input.SoraImagePrice360) + soraImagePrice540 := normalizePrice(input.SoraImagePrice540) + soraVideoPrice := normalizePrice(input.SoraVideoPricePerRequest) + soraVideoPriceHD := normalizePrice(input.SoraVideoPricePerRequestHD) // 校验降级分组 if input.FallbackGroupID != nil { @@ -709,12 +894,17 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ImagePrice1K: imagePrice1K, ImagePrice2K: imagePrice2K, ImagePrice4K: imagePrice4K, + SoraImagePrice360: soraImagePrice360, + SoraImagePrice540: soraImagePrice540, + SoraVideoPricePerRequest: soraVideoPrice, + SoraVideoPricePerRequestHD: soraVideoPriceHD, ClaudeCodeOnly: input.ClaudeCodeOnly, FallbackGroupID: input.FallbackGroupID, FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest, ModelRouting: input.ModelRouting, MCPXMLInject: mcpXMLInject, SupportedModelScopes: input.SupportedModelScopes, + SoraStorageQuotaBytes: input.SoraStorageQuotaBytes, } if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err @@ -865,6 +1055,21 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.ImagePrice4K != nil { group.ImagePrice4K = normalizePrice(input.ImagePrice4K) } + if input.SoraImagePrice360 != nil { + group.SoraImagePrice360 = normalizePrice(input.SoraImagePrice360) + } + if input.SoraImagePrice540 != nil { + group.SoraImagePrice540 = normalizePrice(input.SoraImagePrice540) + } + if 
input.SoraVideoPricePerRequest != nil { + group.SoraVideoPricePerRequest = normalizePrice(input.SoraVideoPricePerRequest) + } + if input.SoraVideoPricePerRequestHD != nil { + group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD) + } + if input.SoraStorageQuotaBytes != nil { + group.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes + } // Claude Code 客户端限制 if input.ClaudeCodeOnly != nil { @@ -993,7 +1198,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { defer cancel() for _, userID := range affectedUserIDs { if err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID); err != nil { - log.Printf("invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err) + logger.LegacyPrintf("service.admin", "invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err) } } }() @@ -1020,6 +1225,103 @@ func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates [] return s.groupRepo.UpdateSortOrders(ctx, updates) } +// AdminUpdateAPIKeyGroupID 管理员修改 API Key 分组绑定 +// groupID: nil=不修改, 指向0=解绑, 指向正整数=绑定到目标分组 +func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error) { + apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID) + if err != nil { + return nil, err + } + + if groupID == nil { + // nil 表示不修改,直接返回 + return &AdminUpdateAPIKeyGroupIDResult{APIKey: apiKey}, nil + } + + if *groupID < 0 { + return nil, infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative") + } + + result := &AdminUpdateAPIKeyGroupIDResult{} + + if *groupID == 0 { + // 0 表示解绑分组(不修改 user_allowed_groups,避免影响用户其他 Key) + apiKey.GroupID = nil + apiKey.Group = nil + } else { + // 验证目标分组存在且状态为 active + group, err := s.groupRepo.GetByID(ctx, *groupID) + if err != nil { + return nil, err + } + if group.Status != StatusActive { + return nil, 
infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active") + } + // 订阅类型分组:不允许通过此 API 直接绑定,需通过订阅管理流程 + if group.IsSubscriptionType() { + return nil, infraerrors.BadRequest("SUBSCRIPTION_GROUP_NOT_ALLOWED", "subscription groups must be managed through the subscription workflow") + } + + gid := *groupID + apiKey.GroupID = &gid + apiKey.Group = group + + // 专属标准分组:使用事务保证「添加分组权限」与「更新 API Key」的原子性 + if group.IsExclusive { + opCtx := ctx + var tx *dbent.Tx + if s.entClient == nil { + logger.LegacyPrintf("service.admin", "Warning: entClient is nil, skipping transaction protection for exclusive group binding") + } else { + var txErr error + tx, txErr = s.entClient.Tx(ctx) + if txErr != nil { + return nil, fmt.Errorf("begin transaction: %w", txErr) + } + defer func() { _ = tx.Rollback() }() + opCtx = dbent.NewTxContext(ctx, tx) + } + + if addErr := s.userRepo.AddGroupToAllowedGroups(opCtx, apiKey.UserID, gid); addErr != nil { + return nil, fmt.Errorf("add group to user allowed groups: %w", addErr) + } + if err := s.apiKeyRepo.Update(opCtx, apiKey); err != nil { + return nil, fmt.Errorf("update api key: %w", err) + } + if tx != nil { + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("commit transaction: %w", err) + } + } + + result.AutoGrantedGroupAccess = true + result.GrantedGroupID = &gid + result.GrantedGroupName = group.Name + + // 失效认证缓存(在事务提交后执行) + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key) + } + + result.APIKey = apiKey + return result, nil + } + } + + // 非专属分组 / 解绑:无需事务,单步更新即可 + if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil { + return nil, fmt.Errorf("update api key: %w", err) + } + + // 失效认证缓存 + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key) + } + + result.APIKey = apiKey + return result, nil +} + // Account management implementations func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, 
platform, accountType, status, search string, groupID int64) ([]Account, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} @@ -1071,6 +1373,18 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou } } + // Sora apikey 账号的 base_url 必填校验 + if input.Platform == PlatformSora && input.Type == AccountTypeAPIKey { + baseURL, _ := input.Credentials["base_url"].(string) + baseURL = strings.TrimSpace(baseURL) + if baseURL == "" { + return nil, errors.New("sora apikey 账号必须设置 base_url") + } + if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { + return nil, errors.New("base_url 必须以 http:// 或 https:// 开头") + } + } + account := &Account{ Name: input.Name, Notes: normalizeAccountNotes(input.Notes), @@ -1103,6 +1417,18 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou return nil, err } + // 如果是 Sora 平台账号,自动创建 sora_accounts 扩展表记录 + if account.Platform == PlatformSora && s.soraAccountRepo != nil { + soraUpdates := map[string]any{ + "access_token": account.GetCredential("access_token"), + "refresh_token": account.GetCredential("refresh_token"), + } + if err := s.soraAccountRepo.Upsert(ctx, account.ID, soraUpdates); err != nil { + // 只记录警告日志,不阻塞账号创建 + logger.LegacyPrintf("service.admin", "[AdminService] 创建 sora_accounts 记录失败: account_id=%d err=%v", account.ID, err) + } + } + // 绑定分组 if len(groupIDs) > 0 { if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil { @@ -1172,12 +1498,22 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U account.AutoPauseOnExpired = *input.AutoPauseOnExpired } + // Sora apikey 账号的 base_url 必填校验 + if account.Platform == PlatformSora && account.Type == AccountTypeAPIKey { + baseURL, _ := account.Credentials["base_url"].(string) + baseURL = strings.TrimSpace(baseURL) + if baseURL == "" { + return nil, errors.New("sora apikey 账号必须设置 base_url") + } + if 
!strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { + return nil, errors.New("base_url 必须以 http:// 或 https:// 开头") + } + } + // 先验证分组是否存在(在任何写操作之前) if input.GroupIDs != nil { - for _, groupID := range *input.GroupIDs { - if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil { - return nil, fmt.Errorf("get group: %w", err) - } + if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil { + return nil, err } // 检查混合渠道风险(除非用户已确认) @@ -1200,7 +1536,11 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U } // 重新查询以确保返回完整数据(包括正确的 Proxy 关联对象) - return s.accountRepo.GetByID(ctx, id) + updated, err := s.accountRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + return updated, nil } // BulkUpdateAccounts updates multiple accounts in one request. @@ -1215,10 +1555,17 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp if len(input.AccountIDs) == 0 { return result, nil } + if input.GroupIDs != nil { + if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil { + return nil, err + } + } - // Preload account platforms for mixed channel risk checks if group bindings are requested. 
+ needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck + + // 预加载账号平台信息(混合渠道检查需要)。 platformByID := map[int64]string{} - if input.GroupIDs != nil && !input.SkipMixedChannelCheck { + if needMixedChannelCheck { accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs) if err != nil { return nil, err @@ -1230,6 +1577,19 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp } } + // 预检查混合渠道风险:在任何写操作之前,若发现风险立即返回错误。 + if needMixedChannelCheck { + for _, accountID := range input.AccountIDs { + platform := platformByID[accountID] + if platform == "" { + continue + } + if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil { + return nil, err + } + } + } + if input.RateMultiplier != nil { if *input.RateMultiplier < 0 { return nil, errors.New("rate_multiplier must be >= 0") @@ -1273,31 +1633,6 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp entry := BulkUpdateAccountResult{AccountID: accountID} if input.GroupIDs != nil { - // 检查混合渠道风险(除非用户已确认) - if !input.SkipMixedChannelCheck { - platform := platformByID[accountID] - if platform == "" { - account, err := s.accountRepo.GetByID(ctx, accountID) - if err != nil { - entry.Success = false - entry.Error = err.Error() - result.Failed++ - result.FailedIDs = append(result.FailedIDs, accountID) - result.Results = append(result.Results, entry) - continue - } - platform = account.Platform - } - if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil { - entry.Success = false - entry.Error = err.Error() - result.Failed++ - result.FailedIDs = append(result.FailedIDs, accountID) - result.Results = append(result.Results, entry) - continue - } - } - if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil { entry.Success = false entry.Error = err.Error() @@ -1318,7 +1653,10 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp } 
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error { - return s.accountRepo.Delete(ctx, id) + if err := s.accountRepo.Delete(ctx, id); err != nil { + return err + } + return nil } func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) { @@ -1351,7 +1689,11 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil { return nil, err } - return s.accountRepo.GetByID(ctx, id) + updated, err := s.accountRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + return updated, nil } // Proxy management implementations @@ -1629,6 +1971,269 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR }, nil } +func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error) { + proxy, err := s.proxyRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + + result := &ProxyQualityCheckResult{ + ProxyID: id, + Score: 100, + Grade: "A", + CheckedAt: time.Now().Unix(), + Items: make([]ProxyQualityCheckItem, 0, len(proxyQualityTargets)+1), + } + + proxyURL := proxy.URL() + if s.proxyProber == nil { + result.Items = append(result.Items, ProxyQualityCheckItem{ + Target: "base_connectivity", + Status: "fail", + Message: "代理探测服务未配置", + }) + result.FailedCount++ + finalizeProxyQualityResult(result) + s.saveProxyQualitySnapshot(ctx, id, result, nil) + return result, nil + } + + exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL) + if err != nil { + result.Items = append(result.Items, ProxyQualityCheckItem{ + Target: "base_connectivity", + Status: "fail", + LatencyMs: latencyMs, + Message: err.Error(), + }) + result.FailedCount++ + finalizeProxyQualityResult(result) + s.saveProxyQualitySnapshot(ctx, id, result, nil) + return result, nil + } + + result.ExitIP = exitInfo.IP + result.Country = exitInfo.Country + 
result.CountryCode = exitInfo.CountryCode + result.BaseLatencyMs = latencyMs + result.Items = append(result.Items, ProxyQualityCheckItem{ + Target: "base_connectivity", + Status: "pass", + LatencyMs: latencyMs, + Message: "代理出口连通正常", + }) + result.PassedCount++ + + client, err := httpclient.GetClient(httpclient.Options{ + ProxyURL: proxyURL, + Timeout: proxyQualityRequestTimeout, + ResponseHeaderTimeout: proxyQualityResponseHeaderTimeout, + }) + if err != nil { + result.Items = append(result.Items, ProxyQualityCheckItem{ + Target: "http_client", + Status: "fail", + Message: fmt.Sprintf("创建检测客户端失败: %v", err), + }) + result.FailedCount++ + finalizeProxyQualityResult(result) + s.saveProxyQualitySnapshot(ctx, id, result, exitInfo) + return result, nil + } + + for _, target := range proxyQualityTargets { + item := runProxyQualityTarget(ctx, client, target) + result.Items = append(result.Items, item) + switch item.Status { + case "pass": + result.PassedCount++ + case "warn": + result.WarnCount++ + case "challenge": + result.ChallengeCount++ + default: + result.FailedCount++ + } + } + + finalizeProxyQualityResult(result) + s.saveProxyQualitySnapshot(ctx, id, result, exitInfo) + return result, nil +} + +func runProxyQualityTarget(ctx context.Context, client *http.Client, target proxyQualityTarget) ProxyQualityCheckItem { + item := ProxyQualityCheckItem{ + Target: target.Target, + } + + req, err := http.NewRequestWithContext(ctx, target.Method, target.URL, nil) + if err != nil { + item.Status = "fail" + item.Message = fmt.Sprintf("构建请求失败: %v", err) + return item + } + req.Header.Set("Accept", "application/json,text/html,*/*") + req.Header.Set("User-Agent", proxyQualityClientUserAgent) + + start := time.Now() + resp, err := client.Do(req) + if err != nil { + item.Status = "fail" + item.LatencyMs = time.Since(start).Milliseconds() + item.Message = fmt.Sprintf("请求失败: %v", err) + return item + } + defer func() { _ = resp.Body.Close() }() + item.LatencyMs = 
time.Since(start).Milliseconds() + item.HTTPStatus = resp.StatusCode + + body, readErr := io.ReadAll(io.LimitReader(resp.Body, proxyQualityMaxBodyBytes+1)) + if readErr != nil { + item.Status = "fail" + item.Message = fmt.Sprintf("读取响应失败: %v", readErr) + return item + } + if int64(len(body)) > proxyQualityMaxBodyBytes { + body = body[:proxyQualityMaxBodyBytes] + } + + if target.Target == "sora" && soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) { + item.Status = "challenge" + item.CFRay = soraerror.ExtractCloudflareRayID(resp.Header, body) + item.Message = "Sora 命中 Cloudflare challenge" + return item + } + + if _, ok := target.AllowedStatuses[resp.StatusCode]; ok { + if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices { + item.Status = "pass" + item.Message = fmt.Sprintf("HTTP %d", resp.StatusCode) + } else { + item.Status = "warn" + item.Message = fmt.Sprintf("HTTP %d(目标可达,但鉴权或方法受限)", resp.StatusCode) + } + return item + } + + if resp.StatusCode == http.StatusTooManyRequests { + item.Status = "warn" + item.Message = "目标返回 429,可能存在频控" + return item + } + + item.Status = "fail" + item.Message = fmt.Sprintf("非预期状态码: %d", resp.StatusCode) + return item +} + +func finalizeProxyQualityResult(result *ProxyQualityCheckResult) { + if result == nil { + return + } + score := 100 - result.WarnCount*10 - result.FailedCount*22 - result.ChallengeCount*30 + if score < 0 { + score = 0 + } + result.Score = score + result.Grade = proxyQualityGrade(score) + result.Summary = fmt.Sprintf( + "通过 %d 项,告警 %d 项,失败 %d 项,挑战 %d 项", + result.PassedCount, + result.WarnCount, + result.FailedCount, + result.ChallengeCount, + ) +} + +func proxyQualityGrade(score int) string { + switch { + case score >= 90: + return "A" + case score >= 75: + return "B" + case score >= 60: + return "C" + case score >= 40: + return "D" + default: + return "F" + } +} + +func proxyQualityOverallStatus(result *ProxyQualityCheckResult) string { + if result == 
nil { + return "" + } + if result.ChallengeCount > 0 { + return "challenge" + } + if result.FailedCount > 0 { + return "failed" + } + if result.WarnCount > 0 { + return "warn" + } + if result.PassedCount > 0 { + return "healthy" + } + return "failed" +} + +func proxyQualityFirstCFRay(result *ProxyQualityCheckResult) string { + if result == nil { + return "" + } + for _, item := range result.Items { + if item.CFRay != "" { + return item.CFRay + } + } + return "" +} + +func proxyQualityBaseConnectivityPass(result *ProxyQualityCheckResult) bool { + if result == nil { + return false + } + for _, item := range result.Items { + if item.Target == "base_connectivity" { + return item.Status == "pass" + } + } + return false +} + +func (s *adminServiceImpl) saveProxyQualitySnapshot(ctx context.Context, proxyID int64, result *ProxyQualityCheckResult, exitInfo *ProxyExitInfo) { + if result == nil { + return + } + score := result.Score + checkedAt := result.CheckedAt + info := &ProxyLatencyInfo{ + Success: proxyQualityBaseConnectivityPass(result), + Message: result.Summary, + QualityStatus: proxyQualityOverallStatus(result), + QualityScore: &score, + QualityGrade: result.Grade, + QualitySummary: result.Summary, + QualityCheckedAt: &checkedAt, + QualityCFRay: proxyQualityFirstCFRay(result), + UpdatedAt: time.Now(), + } + if result.BaseLatencyMs > 0 { + latency := result.BaseLatencyMs + info.LatencyMs = &latency + } + if exitInfo != nil { + info.IPAddress = exitInfo.IP + info.Country = exitInfo.Country + info.CountryCode = exitInfo.CountryCode + info.Region = exitInfo.Region + info.City = exitInfo.City + } + s.saveProxyLatency(ctx, proxyID, info) +} + func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) { if s.proxyProber == nil || proxy == nil { return @@ -1706,6 +2311,40 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc return nil } +func (s *adminServiceImpl) validateGroupIDsExist(ctx context.Context, groupIDs 
[]int64) error { + if len(groupIDs) == 0 { + return nil + } + if s.groupRepo == nil { + return errors.New("group repository not configured") + } + + if batchReader, ok := s.groupRepo.(groupExistenceBatchReader); ok { + existsByID, err := batchReader.ExistsByIDs(ctx, groupIDs) + if err != nil { + return fmt.Errorf("check groups exists: %w", err) + } + for _, groupID := range groupIDs { + if groupID <= 0 || !existsByID[groupID] { + return fmt.Errorf("get group: %w", ErrGroupNotFound) + } + } + return nil + } + + for _, groupID := range groupIDs { + if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil { + return fmt.Errorf("get group: %w", err) + } + } + return nil +} + +// CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform. +func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error { + return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs) +} + func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) { if s.proxyLatencyCache == nil || len(proxies) == 0 { return @@ -1718,7 +2357,7 @@ func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []Pro latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, ids) if err != nil { - log.Printf("Warning: load proxy latency cache failed: %v", err) + logger.LegacyPrintf("service.admin", "Warning: load proxy latency cache failed: %v", err) return } @@ -1739,6 +2378,11 @@ func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []Pro proxies[i].CountryCode = info.CountryCode proxies[i].Region = info.Region proxies[i].City = info.City + proxies[i].QualityStatus = info.QualityStatus + proxies[i].QualityScore = info.QualityScore + proxies[i].QualityGrade = info.QualityGrade + proxies[i].QualitySummary = info.QualitySummary + proxies[i].QualityChecked = 
info.QualityCheckedAt } } @@ -1746,8 +2390,28 @@ func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64, if s.proxyLatencyCache == nil || info == nil { return } - if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, info); err != nil { - log.Printf("Warning: store proxy latency cache failed: %v", err) + + merged := *info + if latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, []int64{proxyID}); err == nil { + if existing := latencies[proxyID]; existing != nil { + if merged.QualityCheckedAt == nil && + merged.QualityScore == nil && + merged.QualityGrade == "" && + merged.QualityStatus == "" && + merged.QualitySummary == "" && + merged.QualityCFRay == "" { + merged.QualityStatus = existing.QualityStatus + merged.QualityScore = existing.QualityScore + merged.QualityGrade = existing.QualityGrade + merged.QualitySummary = existing.QualitySummary + merged.QualityCheckedAt = existing.QualityCheckedAt + merged.QualityCFRay = existing.QualityCFRay + } + } + } + + if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, &merged); err != nil { + logger.LegacyPrintf("service.admin", "Warning: store proxy latency cache failed: %v", err) } } diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go new file mode 100644 index 00000000..a6d12f97 --- /dev/null +++ b/backend/internal/service/admin_service_apikey_test.go @@ -0,0 +1,429 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// Stubs +// --------------------------------------------------------------------------- + +// userRepoStubForGroupUpdate implements UserRepository for AdminUpdateAPIKeyGroupID tests. 
+type userRepoStubForGroupUpdate struct { + addGroupErr error + addGroupCalled bool + addedUserID int64 + addedGroupID int64 +} + +func (s *userRepoStubForGroupUpdate) AddGroupToAllowedGroups(_ context.Context, userID int64, groupID int64) error { + s.addGroupCalled = true + s.addedUserID = userID + s.addedGroupID = groupID + return s.addGroupErr +} + +func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, 
int64, *string) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") } + +// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests. +type apiKeyRepoStubForGroupUpdate struct { + key *APIKey + getErr error + updateErr error + updated *APIKey // captures what was passed to Update +} + +func (s *apiKeyRepoStubForGroupUpdate) GetByID(_ context.Context, _ int64) (*APIKey, error) { + if s.getErr != nil { + return nil, s.getErr + } + clone := *s.key + return &clone, nil +} +func (s *apiKeyRepoStubForGroupUpdate) Update(_ context.Context, key *APIKey) error { + if s.updateErr != nil { + return s.updateErr + } + clone := *key + s.updated = &clone + return nil +} + +// Unused methods – panic on unexpected call. +func (s *apiKeyRepoStubForGroupUpdate) Create(context.Context, *APIKey) error { panic("unexpected") } +func (s *apiKeyRepoStubForGroupUpdate) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) GetByKey(context.Context, string) (*APIKey, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) GetByKeyForAuth(context.Context, string) (*APIKey, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") } +func (s *apiKeyRepoStubForGroupUpdate) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) CountByUserID(context.Context, int64) (int64, error) { + panic("unexpected") +} +func (s 
*apiKeyRepoStubForGroupUpdate) ExistsByKey(context.Context, string) (bool, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) ClearGroupIDByGroupID(context.Context, int64) (int64, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) CountByGroupID(context.Context, int64) (int64, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) ListKeysByUserID(context.Context, int64) ([]string, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) ListKeysByGroupID(context.Context, int64) ([]string, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) IncrementQuotaUsed(context.Context, int64, float64) (float64, error) { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) UpdateLastUsed(context.Context, int64, time.Time) error { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) IncrementRateLimitUsage(context.Context, int64, float64) error { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) ResetRateLimitWindows(context.Context, int64) error { + panic("unexpected") +} +func (s *apiKeyRepoStubForGroupUpdate) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) { + panic("unexpected") +} + +// groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests. 
+type groupRepoStubForGroupUpdate struct { + group *Group + getErr error + lastGetByIDArg int64 +} + +func (s *groupRepoStubForGroupUpdate) GetByID(_ context.Context, id int64) (*Group, error) { + s.lastGetByIDArg = id + if s.getErr != nil { + return nil, s.getErr + } + clone := *s.group + return &clone, nil +} + +// Unused methods – panic on unexpected call. +func (s *groupRepoStubForGroupUpdate) Create(context.Context, *Group) error { panic("unexpected") } +func (s *groupRepoStubForGroupUpdate) GetByIDLite(context.Context, int64) (*Group, error) { + panic("unexpected") +} +func (s *groupRepoStubForGroupUpdate) Update(context.Context, *Group) error { panic("unexpected") } +func (s *groupRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") } +func (s *groupRepoStubForGroupUpdate) DeleteCascade(context.Context, int64) ([]int64, error) { + panic("unexpected") +} +func (s *groupRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) { + panic("unexpected") +} +func (s *groupRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]Group, *pagination.PaginationResult, error) { + panic("unexpected") +} +func (s *groupRepoStubForGroupUpdate) ListActive(context.Context) ([]Group, error) { + panic("unexpected") +} +func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, string) ([]Group, error) { + panic("unexpected") +} +func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) { + panic("unexpected") +} +func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) { + panic("unexpected") +} +func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { + panic("unexpected") +} +func (s *groupRepoStubForGroupUpdate) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) { + 
panic("unexpected") +} +func (s *groupRepoStubForGroupUpdate) BindAccountsToGroup(context.Context, int64, []int64) error { + panic("unexpected") +} +func (s *groupRepoStubForGroupUpdate) UpdateSortOrders(context.Context, []GroupSortOrderUpdate) error { + panic("unexpected") +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +func TestAdminService_AdminUpdateAPIKeyGroupID_KeyNotFound(t *testing.T) { + repo := &apiKeyRepoStubForGroupUpdate{getErr: ErrAPIKeyNotFound} + svc := &adminServiceImpl{apiKeyRepo: repo} + + _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 999, int64Ptr(1)) + require.ErrorIs(t, err, ErrAPIKeyNotFound) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_NilGroupID_NoOp(t *testing.T) { + existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(5)} + repo := &apiKeyRepoStubForGroupUpdate{key: existing} + svc := &adminServiceImpl{apiKeyRepo: repo} + + got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, nil) + require.NoError(t, err) + require.Equal(t, int64(1), got.APIKey.ID) + // Update should NOT have been called (updated stays nil) + require.Nil(t, repo.updated) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_Unbind(t *testing.T) { + existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(5), Group: &Group{ID: 5, Name: "Old"}} + repo := &apiKeyRepoStubForGroupUpdate{key: existing} + cache := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{apiKeyRepo: repo, authCacheInvalidator: cache} + + got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0)) + require.NoError(t, err) + require.Nil(t, got.APIKey.GroupID, "group_id should be nil after unbind") + require.Nil(t, got.APIKey.Group, "group object should be nil after unbind") + require.NotNil(t, repo.updated, "Update should have been called") + require.Nil(t, repo.updated.GroupID) + require.Equal(t, 
[]string{"sk-test"}, cache.keys, "cache should be invalidated") +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_BindActiveGroup(t *testing.T) { + existing := &APIKey{ID: 1, Key: "sk-test", GroupID: nil} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}} + cache := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache} + + got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) + require.NoError(t, err) + require.NotNil(t, got.APIKey.GroupID) + require.Equal(t, int64(10), *got.APIKey.GroupID) + require.Equal(t, int64(10), *apiKeyRepo.updated.GroupID) + require.Equal(t, []string{"sk-test"}, cache.keys) + // M3: verify correct group ID was passed to repo + require.Equal(t, int64(10), groupRepo.lastGetByIDArg) + // C1 fix: verify Group object is populated + require.NotNil(t, got.APIKey.Group) + require.Equal(t, "Pro", got.APIKey.Group.Name) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_SameGroup_Idempotent(t *testing.T) { + existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(10), Group: &Group{ID: 10, Name: "Pro"}} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}} + cache := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache} + + got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) + require.NoError(t, err) + require.NotNil(t, got.APIKey.GroupID) + require.Equal(t, int64(10), *got.APIKey.GroupID) + // Update is still called (current impl doesn't short-circuit on same group) + require.NotNil(t, apiKeyRepo.updated) + require.Equal(t, []string{"sk-test"}, cache.keys) +} + +func 
TestAdminService_AdminUpdateAPIKeyGroupID_GroupNotFound(t *testing.T) { + existing := &APIKey{ID: 1, Key: "sk-test"} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{getErr: ErrGroupNotFound} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo} + + _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(99)) + require.ErrorIs(t, err, ErrGroupNotFound) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_GroupNotActive(t *testing.T) { + existing := &APIKey{ID: 1, Key: "sk-test"} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 5, Status: StatusDisabled}} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo} + + _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(5)) + require.Error(t, err) + require.Equal(t, "GROUP_NOT_ACTIVE", infraerrors.Reason(err)) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_UpdateFails(t *testing.T) { + existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(3)} + repo := &apiKeyRepoStubForGroupUpdate{key: existing, updateErr: errors.New("db write error")} + svc := &adminServiceImpl{apiKeyRepo: repo} + + _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0)) + require.Error(t, err) + require.Contains(t, err.Error(), "update api key") +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_NegativeGroupID(t *testing.T) { + existing := &APIKey{ID: 1, Key: "sk-test"} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo} + + _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(-5)) + require.Error(t, err) + require.Equal(t, "INVALID_GROUP_ID", infraerrors.Reason(err)) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_PointerIsolation(t *testing.T) { + existing := &APIKey{ID: 1, Key: "sk-test", GroupID: nil} + apiKeyRepo := 
&apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}} + cache := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache} + + inputGID := int64(10) + got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, &inputGID) + require.NoError(t, err) + require.NotNil(t, got.APIKey.GroupID) + // Mutating the input pointer must NOT affect the stored value + inputGID = 999 + require.Equal(t, int64(10), *got.APIKey.GroupID) + require.Equal(t, int64(10), *apiKeyRepo.updated.GroupID) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_NilCacheInvalidator(t *testing.T) { + existing := &APIKey{ID: 1, Key: "sk-test"} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 7, Status: StatusActive}} + // authCacheInvalidator is nil – should not panic + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo} + + got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(7)) + require.NoError(t, err) + require.NotNil(t, got.APIKey.GroupID) + require.Equal(t, int64(7), *got.APIKey.GroupID) +} + +// --------------------------------------------------------------------------- +// Tests: AllowedGroup auto-sync +// --------------------------------------------------------------------------- + +func TestAdminService_AdminUpdateAPIKeyGroupID_ExclusiveGroup_AddsAllowedGroup(t *testing.T) { + existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Exclusive", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeStandard}} + userRepo := &userRepoStubForGroupUpdate{} + cache := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, 
groupRepo: groupRepo, userRepo: userRepo, authCacheInvalidator: cache} + + got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) + require.NoError(t, err) + require.NotNil(t, got.APIKey.GroupID) + require.Equal(t, int64(10), *got.APIKey.GroupID) + // 验证 AddGroupToAllowedGroups 被调用,且参数正确 + require.True(t, userRepo.addGroupCalled) + require.Equal(t, int64(42), userRepo.addedUserID) + require.Equal(t, int64(10), userRepo.addedGroupID) + // 验证 result 标记了自动授权 + require.True(t, got.AutoGrantedGroupAccess) + require.NotNil(t, got.GrantedGroupID) + require.Equal(t, int64(10), *got.GrantedGroupID) + require.Equal(t, "Exclusive", got.GrantedGroupName) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_NonExclusiveGroup_NoAllowedGroupUpdate(t *testing.T) { + existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Public", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeStandard}} + userRepo := &userRepoStubForGroupUpdate{} + cache := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, authCacheInvalidator: cache} + + got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) + require.NoError(t, err) + require.NotNil(t, got.APIKey.GroupID) + // 非专属分组不触发 AddGroupToAllowedGroups + require.False(t, userRepo.addGroupCalled) + require.False(t, got.AutoGrantedGroupAccess) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *testing.T) { + existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}} + userRepo := 
&userRepoStubForGroupUpdate{} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo} + + // 订阅类型分组应被阻止绑定 + _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) + require.Error(t, err) + require.Equal(t, "SUBSCRIPTION_GROUP_NOT_ALLOWED", infraerrors.Reason(err)) + require.False(t, userRepo.addGroupCalled) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_ExclusiveGroup_AllowedGroupAddFails_ReturnsError(t *testing.T) { + existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Exclusive", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeStandard}} + userRepo := &userRepoStubForGroupUpdate{addGroupErr: errors.New("db error")} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo} + + // 严格模式:AddGroupToAllowedGroups 失败时,整体操作报错 + _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10)) + require.Error(t, err) + require.Contains(t, err.Error(), "add group to user allowed groups") + require.True(t, userRepo.addGroupCalled) + // apiKey 不应被更新 + require.Nil(t, apiKeyRepo.updated) +} + +func TestAdminService_AdminUpdateAPIKeyGroupID_Unbind_NoAllowedGroupUpdate(t *testing.T) { + existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: int64Ptr(10), Group: &Group{ID: 10, Name: "Exclusive"}} + apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing} + userRepo := &userRepoStubForGroupUpdate{} + cache := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, userRepo: userRepo, authCacheInvalidator: cache} + + got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0)) + require.NoError(t, err) + require.Nil(t, got.APIKey.GroupID) + // 解绑时不修改 allowed_groups + require.False(t, userRepo.addGroupCalled) + require.False(t, 
got.AutoGrantedGroupAccess) +} diff --git a/backend/internal/service/admin_service_bulk_update_test.go b/backend/internal/service/admin_service_bulk_update_test.go index 662b95fb..4845d87c 100644 --- a/backend/internal/service/admin_service_bulk_update_test.go +++ b/backend/internal/service/admin_service_bulk_update_test.go @@ -15,6 +15,16 @@ type accountRepoStubForBulkUpdate struct { bulkUpdateErr error bulkUpdateIDs []int64 bindGroupErrByID map[int64]error + bindGroupsCalls []int64 + getByIDsAccounts []*Account + getByIDsErr error + getByIDsCalled bool + getByIDsIDs []int64 + getByIDAccounts map[int64]*Account + getByIDErrByID map[int64]error + getByIDCalled []int64 + listByGroupData map[int64][]Account + listByGroupErr map[int64]error } func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) { @@ -26,12 +36,43 @@ func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64 } func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID int64, _ []int64) error { + s.bindGroupsCalls = append(s.bindGroupsCalls, accountID) if err, ok := s.bindGroupErrByID[accountID]; ok { return err } return nil } +func (s *accountRepoStubForBulkUpdate) GetByIDs(_ context.Context, ids []int64) ([]*Account, error) { + s.getByIDsCalled = true + s.getByIDsIDs = append([]int64{}, ids...) 
+ if s.getByIDsErr != nil { + return nil, s.getByIDsErr + } + return s.getByIDsAccounts, nil +} + +func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Account, error) { + s.getByIDCalled = append(s.getByIDCalled, id) + if err, ok := s.getByIDErrByID[id]; ok { + return nil, err + } + if account, ok := s.getByIDAccounts[id]; ok { + return account, nil + } + return nil, errors.New("account not found") +} + +func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) { + if err, ok := s.listByGroupErr[groupID]; ok { + return nil, err + } + if rows, ok := s.listByGroupData[groupID]; ok { + return rows, nil + } + return nil, nil +} + // TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。 func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) { repo := &accountRepoStubForBulkUpdate{} @@ -59,7 +100,10 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) { 2: errors.New("bind failed"), }, } - svc := &adminServiceImpl{accountRepo: repo} + svc := &adminServiceImpl{ + accountRepo: repo, + groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "g10"}}, + } groupIDs := []int64{10} schedulable := false @@ -78,3 +122,51 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) { require.ElementsMatch(t, []int64{2}, result.FailedIDs) require.Len(t, result.Results, 3) } + +func TestAdminService_BulkUpdateAccounts_NilGroupRepoReturnsError(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{} + svc := &adminServiceImpl{accountRepo: repo} + + groupIDs := []int64{10} + input := &BulkUpdateAccountsInput{ + AccountIDs: []int64{1}, + GroupIDs: &groupIDs, + } + + result, err := svc.BulkUpdateAccounts(context.Background(), input) + require.Nil(t, result) + require.Error(t, err) + require.Contains(t, err.Error(), "group repository not configured") +} + +// 
TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingConflict verifies +// that the global pre-check detects a conflict with existing group members and returns an +// error before any DB write is performed. +func TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingConflict(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{ + getByIDsAccounts: []*Account{ + {ID: 1, Platform: PlatformAntigravity}, + }, + // Group 10 already contains an Anthropic account. + listByGroupData: map[int64][]Account{ + 10: {{ID: 99, Platform: PlatformAnthropic}}, + }, + } + svc := &adminServiceImpl{ + accountRepo: repo, + groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "target-group"}}, + } + + groupIDs := []int64{10} + input := &BulkUpdateAccountsInput{ + AccountIDs: []int64{1}, + GroupIDs: &groupIDs, + } + + result, err := svc.BulkUpdateAccounts(context.Background(), input) + require.Nil(t, result) + require.Error(t, err) + require.Contains(t, err.Error(), "mixed channel") + // No BindGroups should have been called since the check runs before any write. 
+ require.Empty(t, repo.bindGroupsCalls) +} diff --git a/backend/internal/service/admin_service_create_user_test.go b/backend/internal/service/admin_service_create_user_test.go index a0fe4d87..c5b1e38d 100644 --- a/backend/internal/service/admin_service_create_user_test.go +++ b/backend/internal/service/admin_service_create_user_test.go @@ -7,6 +7,7 @@ import ( "errors" "testing" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/stretchr/testify/require" ) @@ -65,3 +66,32 @@ func TestAdminService_CreateUser_CreateError(t *testing.T) { require.ErrorIs(t, err, createErr) require.Empty(t, repo.created) } + +func TestAdminService_CreateUser_AssignsDefaultSubscriptions(t *testing.T) { + repo := &userRepoStub{nextID: 21} + assigner := &defaultSubscriptionAssignerStub{} + cfg := &config.Config{ + Default: config.DefaultConfig{ + UserBalance: 0, + UserConcurrency: 1, + }, + } + settingService := NewSettingService(&settingRepoStub{values: map[string]string{ + SettingKeyDefaultSubscriptions: `[{"group_id":5,"validity_days":30}]`, + }}, cfg) + svc := &adminServiceImpl{ + userRepo: repo, + settingService: settingService, + defaultSubAssigner: assigner, + } + + _, err := svc.CreateUser(context.Background(), &CreateUserInput{ + Email: "new-user@test.com", + Password: "password", + }) + require.NoError(t, err) + require.Len(t, assigner.calls, 1) + require.Equal(t, int64(21), assigner.calls[0].UserID) + require.Equal(t, int64(5), assigner.calls[0].GroupID) + require.Equal(t, 30, assigner.calls[0].ValidityDays) +} diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index 60fa3d77..2e0f7d90 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -93,6 +93,10 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID panic("unexpected RemoveGroupFromAllowedGroups call") } +func (s *userRepoStub) 
AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error { + panic("unexpected AddGroupToAllowedGroups call") +} + func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { panic("unexpected UpdateTotpSecret call") } @@ -344,6 +348,19 @@ func (s *billingCacheStub) InvalidateSubscriptionCache(ctx context.Context, user return nil } +func (s *billingCacheStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) { + panic("unexpected GetAPIKeyRateLimit call") +} +func (s *billingCacheStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error { + panic("unexpected SetAPIKeyRateLimit call") +} +func (s *billingCacheStub) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error { + panic("unexpected UpdateAPIKeyRateLimitUsage call") +} +func (s *billingCacheStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error { + panic("unexpected InvalidateAPIKeyRateLimit call") +} + func waitForInvalidations(t *testing.T, ch <-chan subscriptionInvalidateCall, expected int) []subscriptionInvalidateCall { t.Helper() calls := make([]subscriptionInvalidateCall, 0, expected) diff --git a/backend/internal/service/admin_service_list_users_test.go b/backend/internal/service/admin_service_list_users_test.go new file mode 100644 index 00000000..8b50530a --- /dev/null +++ b/backend/internal/service/admin_service_list_users_test.go @@ -0,0 +1,106 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type userRepoStubForListUsers struct { + userRepoStub + users []User + err error +} + +func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) { + if s.err != nil { + 
return nil, nil, s.err + } + out := make([]User, len(s.users)) + copy(out, s.users) + return out, &pagination.PaginationResult{ + Total: int64(len(out)), + Page: params.Page, + PageSize: params.PageSize, + }, nil +} + +type userGroupRateRepoStubForListUsers struct { + batchCalls int + singleCall []int64 + + batchErr error + batchData map[int64]map[int64]float64 + + singleErr map[int64]error + singleData map[int64]map[int64]float64 +} + +func (s *userGroupRateRepoStubForListUsers) GetByUserIDs(_ context.Context, _ []int64) (map[int64]map[int64]float64, error) { + s.batchCalls++ + if s.batchErr != nil { + return nil, s.batchErr + } + return s.batchData, nil +} + +func (s *userGroupRateRepoStubForListUsers) GetByUserID(_ context.Context, userID int64) (map[int64]float64, error) { + s.singleCall = append(s.singleCall, userID) + if err, ok := s.singleErr[userID]; ok { + return nil, err + } + if rates, ok := s.singleData[userID]; ok { + return rates, nil + } + return map[int64]float64{}, nil +} + +func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context, userID, groupID int64) (*float64, error) { + panic("unexpected GetByUserAndGroup call") +} + +func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error { + panic("unexpected SyncUserGroupRates call") +} + +func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, groupID int64) error { + panic("unexpected DeleteByGroupID call") +} + +func (s *userGroupRateRepoStubForListUsers) DeleteByUserID(_ context.Context, userID int64) error { + panic("unexpected DeleteByUserID call") +} + +func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) { + userRepo := &userRepoStubForListUsers{ + users: []User{ + {ID: 101, Username: "u1"}, + {ID: 202, Username: "u2"}, + }, + } + rateRepo := &userGroupRateRepoStubForListUsers{ + batchErr: errors.New("batch unavailable"), + singleData: map[int64]map[int64]float64{ 
+ 101: {11: 1.1}, + 202: {22: 2.2}, + }, + } + svc := &adminServiceImpl{ + userRepo: userRepo, + userGroupRateRepo: rateRepo, + } + + users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{}) + require.NoError(t, err) + require.Equal(t, int64(2), total) + require.Len(t, users, 2) + require.Equal(t, 1, rateRepo.batchCalls) + require.ElementsMatch(t, []int64{101, 202}, rateRepo.singleCall) + require.Equal(t, 1.1, users[0].GroupRates[11]) + require.Equal(t, 2.2, users[1].GroupRates[22]) +} diff --git a/backend/internal/service/admin_service_proxy_quality_test.go b/backend/internal/service/admin_service_proxy_quality_test.go new file mode 100644 index 00000000..5a43cd9c --- /dev/null +++ b/backend/internal/service/admin_service_proxy_quality_test.go @@ -0,0 +1,95 @@ +package service + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFinalizeProxyQualityResult_ScoreAndGrade(t *testing.T) { + result := &ProxyQualityCheckResult{ + PassedCount: 2, + WarnCount: 1, + FailedCount: 1, + ChallengeCount: 1, + } + + finalizeProxyQualityResult(result) + + require.Equal(t, 38, result.Score) + require.Equal(t, "F", result.Grade) + require.Contains(t, result.Summary, "通过 2 项") + require.Contains(t, result.Summary, "告警 1 项") + require.Contains(t, result.Summary, "失败 1 项") + require.Contains(t, result.Summary, "挑战 1 项") +} + +func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.Header().Set("cf-ray", "test-ray-123") + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte("Just a moment...")) + })) + defer server.Close() + + target := proxyQualityTarget{ + Target: "sora", + URL: server.URL, + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusUnauthorized: {}, + }, + } + + item := 
runProxyQualityTarget(context.Background(), server.Client(), target) + require.Equal(t, "challenge", item.Status) + require.Equal(t, http.StatusForbidden, item.HTTPStatus) + require.Equal(t, "test-ray-123", item.CFRay) +} + +func TestRunProxyQualityTarget_AllowedStatusPass(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"models":[]}`)) + })) + defer server.Close() + + target := proxyQualityTarget{ + Target: "gemini", + URL: server.URL, + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusOK: {}, + }, + } + + item := runProxyQualityTarget(context.Background(), server.Client(), target) + require.Equal(t, "pass", item.Status) + require.Equal(t, http.StatusOK, item.HTTPStatus) +} + +func TestRunProxyQualityTarget_AllowedStatusWarnForUnauthorized(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"unauthorized"}`)) + })) + defer server.Close() + + target := proxyQualityTarget{ + Target: "openai", + URL: server.URL, + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusUnauthorized: {}, + }, + } + + item := runProxyQualityTarget(context.Background(), server.Client(), target) + require.Equal(t, "warn", item.Status) + require.Equal(t, http.StatusUnauthorized, item.HTTPStatus) + require.Contains(t, item.Message, "目标可达") +} diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 9e7ed7d5..96ff3354 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -21,9 +21,10 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" - "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + 
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/gin-gonic/gin" "github.com/google/uuid" + "github.com/tidwall/gjson" ) const ( @@ -85,7 +86,6 @@ var ( ) const ( - antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL" antigravityForwardBaseURLEnv = "GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL" antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS" ) @@ -184,7 +184,7 @@ type smartRetryResult struct { func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParams, resp *http.Response, respBody []byte, baseURL string, urlIdx int, availableURLs []string) *smartRetryResult { // "Resource has been exhausted" 是 URL 级别限流,切换 URL(仅 429) if resp.StatusCode == http.StatusTooManyRequests && isURLLevelRateLimit(respBody) && urlIdx < len(availableURLs)-1 { - log.Printf("%s URL fallback (429): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1]) + logger.LegacyPrintf("service.antigravity_gateway", "%s URL fallback (429): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1]) return &smartRetryResult{action: smartRetryActionContinueURL} } @@ -204,13 +204,13 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam if rateLimitDuration <= 0 { rateLimitDuration = antigravityDefaultRateLimitDuration } - log.Printf("%s status=%d oauth_long_delay model=%s account=%d upstream_retry_delay=%v body=%s (model rate limit, switch account)", + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d oauth_long_delay model=%s account=%d upstream_retry_delay=%v body=%s (model rate limit, switch account)", p.prefix, resp.StatusCode, modelName, p.account.ID, rateLimitDuration, truncateForLog(respBody, 200)) resetAt := time.Now().Add(rateLimitDuration) if !setModelRateLimitByModelName(p.ctx, p.accountRepo, p.account.ID, modelName, p.prefix, resp.StatusCode, resetAt, false) { p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, 
p.sessionHash, p.isStickySession) - log.Printf("%s status=%d rate_limited account=%d (no model mapping)", p.prefix, resp.StatusCode, p.account.ID) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d rate_limited account=%d (no model mapping)", p.prefix, resp.StatusCode, p.account.ID) } else { s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt) } @@ -273,7 +273,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam // 智能重试:创建新请求 retryReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body) if err != nil { - log.Printf("%s status=smart_retry_request_build_failed error=%v", p.prefix, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=smart_retry_request_build_failed error=%v", p.prefix, err) p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession) return &smartRetryResult{ action: smartRetryActionBreakWithResp, @@ -356,7 +356,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam // 单账号 503 退避重试模式:智能重试耗尽后不设限流、不切换账号, // 直接返回 503 让 Handler 层的单账号退避循环做最终处理。 if resp.StatusCode == http.StatusServiceUnavailable && isSingleAccountRetry(p.ctx) { - log.Printf("%s status=%d smart_retry_exhausted_single_account attempts=%d model=%s account=%d body=%s (return 503 directly)", + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d smart_retry_exhausted_single_account attempts=%d model=%s account=%d body=%s (return 503 directly)", p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID, truncateForLog(retryBody, 200)) return &smartRetryResult{ action: smartRetryActionBreakWithResp, @@ -374,9 +374,9 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam resetAt := time.Now().Add(rateLimitDuration) if p.accountRepo != nil && modelName != "" { if err := 
p.accountRepo.SetModelRateLimit(p.ctx, p.account.ID, modelName, resetAt); err != nil { - log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", p.prefix, resp.StatusCode, modelName, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limit_failed model=%s error=%v", p.prefix, resp.StatusCode, modelName, err) } else { - log.Printf("%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v", + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v", p.prefix, resp.StatusCode, modelName, p.account.ID, rateLimitDuration) s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt) } @@ -431,7 +431,7 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace( waitDuration = antigravitySmartRetryMinWait } - log.Printf("%s status=%d single_account_503_retry_in_place model=%s account=%d upstream_retry_delay=%v (retrying in-place instead of rate-limiting)", + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d single_account_503_retry_in_place model=%s account=%d upstream_retry_delay=%v (retrying in-place instead of rate-limiting)", p.prefix, resp.StatusCode, modelName, p.account.ID, waitDuration) var lastRetryResp *http.Response @@ -443,21 +443,21 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace( if totalWaited+waitDuration > antigravitySingleAccountSmartRetryTotalMaxWait { remaining := antigravitySingleAccountSmartRetryTotalMaxWait - totalWaited if remaining <= 0 { - log.Printf("%s single_account_503_retry: total_wait_exceeded total=%v max=%v, giving up", + logger.LegacyPrintf("service.antigravity_gateway", "%s single_account_503_retry: total_wait_exceeded total=%v max=%v, giving up", p.prefix, totalWaited, antigravitySingleAccountSmartRetryTotalMaxWait) break } waitDuration = remaining } - log.Printf("%s status=%d single_account_503_retry attempt=%d/%d delay=%v 
total_waited=%v model=%s account=%d", + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d single_account_503_retry attempt=%d/%d delay=%v total_waited=%v model=%s account=%d", p.prefix, resp.StatusCode, attempt, antigravitySingleAccountSmartRetryMaxAttempts, waitDuration, totalWaited, modelName, p.account.ID) timer := time.NewTimer(waitDuration) select { case <-p.ctx.Done(): timer.Stop() - log.Printf("%s status=context_canceled_during_single_account_retry", p.prefix) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_single_account_retry", p.prefix) return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()} case <-timer.C: } @@ -466,13 +466,13 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace( // 创建新请求 retryReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body) if err != nil { - log.Printf("%s single_account_503_retry: request_build_failed error=%v", p.prefix, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s single_account_503_retry: request_build_failed error=%v", p.prefix, err) break } retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency) if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable { - log.Printf("%s status=%d single_account_503_retry_success attempt=%d/%d total_waited=%v", + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d single_account_503_retry_success attempt=%d/%d total_waited=%v", p.prefix, retryResp.StatusCode, attempt, antigravitySingleAccountSmartRetryMaxAttempts, totalWaited) // 关闭之前的响应 if lastRetryResp != nil { @@ -483,7 +483,7 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace( // 网络错误时继续重试 if retryErr != nil || retryResp == nil { - log.Printf("%s single_account_503_retry: network_error attempt=%d/%d error=%v", + 
logger.LegacyPrintf("service.antigravity_gateway", "%s single_account_503_retry: network_error attempt=%d/%d error=%v", p.prefix, attempt, antigravitySingleAccountSmartRetryMaxAttempts, retryErr) continue } @@ -517,7 +517,7 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace( if retryBody == nil { retryBody = respBody } - log.Printf("%s status=%d single_account_503_retry_exhausted attempts=%d total_waited=%v model=%s account=%d body=%s (return 503 directly)", + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d single_account_503_retry_exhausted attempts=%d total_waited=%v model=%s account=%d body=%s (return 503 directly)", p.prefix, resp.StatusCode, antigravitySingleAccountSmartRetryMaxAttempts, totalWaited, modelName, p.account.ID, truncateForLog(retryBody, 200)) return &smartRetryResult{ @@ -540,10 +540,10 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP // 如果上游确实还不可用,handleSmartRetry → handleSingleAccountRetryInPlace // 会在 Service 层原地等待+重试,不需要在预检查这里等。 if isSingleAccountRetry(p.ctx) { - log.Printf("%s pre_check: single_account_retry skipping rate_limit remaining=%v model=%s account=%d (will retry in-place if 503)", + logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: single_account_retry skipping rate_limit remaining=%v model=%s account=%d (will retry in-place if 503)", p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID) } else { - log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d", + logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: rate_limit_switch remaining=%v model=%s account=%d", p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID) return nil, &AntigravityAccountSwitchError{ OriginalAccountID: p.account.ID, @@ -580,7 +580,7 @@ urlFallbackLoop: for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { select { case <-p.ctx.Done(): - log.Printf("%s status=context_canceled 
error=%v", p.prefix, p.ctx.Err()) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled error=%v", p.prefix, p.ctx.Err()) return nil, p.ctx.Err() default: } @@ -610,18 +610,18 @@ urlFallbackLoop: Message: safeErr, }) if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { - log.Printf("%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1]) + logger.LegacyPrintf("service.antigravity_gateway", "%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1]) continue urlFallbackLoop } if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err) if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", p.prefix) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_backoff", p.prefix) return nil, p.ctx.Err() } continue } - log.Printf("%s status=request_failed retries_exhausted error=%v", p.prefix, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=request_failed retries_exhausted error=%v", p.prefix, err) setOpsUpstreamError(p.c, 0, safeErr, "") return nil, fmt.Errorf("upstream request failed after retries: %w", err) } @@ -678,9 +678,9 @@ urlFallbackLoop: Message: upstreamMsg, Detail: getUpstreamDetail(respBody), }) - log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { - log.Printf("%s 
status=context_canceled_during_backoff", p.prefix) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_backoff", p.prefix) return nil, p.ctx.Err() } continue @@ -688,7 +688,7 @@ urlFallbackLoop: // 重试用尽,标记账户限流 p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession) - log.Printf("%s status=%d rate_limited base_url=%s body=%s", p.prefix, resp.StatusCode, baseURL, truncateForLog(respBody, 200)) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d rate_limited base_url=%s body=%s", p.prefix, resp.StatusCode, baseURL, truncateForLog(respBody, 200)) resp = &http.Response{ StatusCode: resp.StatusCode, Header: resp.Header.Clone(), @@ -712,9 +712,9 @@ urlFallbackLoop: Message: upstreamMsg, Detail: getUpstreamDetail(respBody), }) - log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", p.prefix) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_backoff", p.prefix) return nil, p.ctx.Err() } continue @@ -1012,14 +1012,14 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account } // 调试日志:Test 请求信息 - log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String()) + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String()) // 发送请求 resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) if err != nil { lastErr = fmt.Errorf("请求失败: %w", err) 
if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { - log.Printf("[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) continue } return nil, lastErr @@ -1034,7 +1034,7 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account // 检查是否需要 URL 降级 if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { - log.Printf("[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) continue } @@ -1243,16 +1243,12 @@ func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model strin } // unwrapV1InternalResponse 解包 v1internal 响应 +// 使用 gjson 零拷贝提取 response 字段,避免 Unmarshal+Marshal 双重开销 func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byte, error) { - var outer map[string]any - if err := json.Unmarshal(body, &outer); err != nil { - return nil, err + result := gjson.GetBytes(body, "response") + if result.Exists() { + return []byte(result.Raw), nil } - - if resp, ok := outer["response"]; ok { - return json.Marshal(resp) - } - return body, nil } @@ -1311,6 +1307,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本 thinkingEnabled := claudeReq.Thinking != nil && (claudeReq.Thinking.Type == "enabled" || claudeReq.Thinking.Type == "adaptive") mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled) + billingModel := mappedModel // 获取 access_token if s.tokenProvider == nil { @@ -1372,6 +1369,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, 
ForceCacheBilling: switchErr.IsStickySession, } } + // 区分客户端取消和真正的上游失败,返回更准确的错误消息 + if c.Request.Context().Err() != nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "client_disconnected", "Client disconnected before upstream response") + } return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") } resp := result.resp @@ -1420,7 +1421,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, continue } - log.Printf("Antigravity account %d: detected signature-related 400, retrying once (%s)", account.ID, stage.name) + logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: detected signature-related 400, retrying once (%s)", account.ID, stage.name) retryGeminiBody, txErr := antigravity.TransformClaudeToGeminiWithOptions(&retryClaudeReq, projectID, mappedModel, s.getClaudeTransformOptions(ctx)) if txErr != nil { @@ -1453,7 +1454,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, Kind: "signature_retry_request_error", Message: sanitizeUpstreamErrorMessage(retryErr.Error()), }) - log.Printf("Antigravity account %d: signature retry request failed (%s): %v", account.ID, stage.name, retryErr) + logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: signature retry request failed (%s): %v", account.ID, stage.name, retryErr) continue } @@ -1472,7 +1473,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, if retryResp.Request != nil && retryResp.Request.URL != nil { retryBaseURL = retryResp.Request.URL.Scheme + "://" + retryResp.Request.URL.Host } - log.Printf("%s status=429 rate_limited base_url=%s retry_stage=%s body=%s", prefix, retryBaseURL, stage.name, truncateForLog(retryBody, 200)) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 rate_limited base_url=%s retry_stage=%s body=%s", prefix, retryBaseURL, stage.name, truncateForLog(retryBody, 
200)) } kind := "signature_retry" if strings.TrimSpace(stage.name) != "" { @@ -1525,7 +1526,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, upstreamDetail := s.getUpstreamErrorDetail(respBody) logBody, maxBytes := s.getLogConfig() if logBody { - log.Printf("%s status=400 prompt_too_long=true upstream_message=%q request_id=%s body=%s", prefix, upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, maxBytes)) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=400 prompt_too_long=true upstream_message=%q request_id=%s body=%s", prefix, upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, maxBytes)) } appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, @@ -1600,7 +1601,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // 客户端要求流式,直接透传转换 streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel) if err != nil { - log.Printf("%s status=stream_error error=%v", prefix, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err) return nil, err } usage = streamRes.usage @@ -1610,7 +1611,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // 客户端要求非流式,收集流式响应后转换返回 streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel) if err != nil { - log.Printf("%s status=stream_collect_error error=%v", prefix, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err) return nil, err } usage = streamRes.usage @@ -1620,7 +1621,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, return &ForwardResult{ RequestID: requestID, Usage: *usage, - Model: originalModel, // 使用原始模型用于计费和日志 + Model: billingModel, // 使用映射模型用于计费和日志 Stream: claudeReq.Stream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, @@ -1963,7 +1964,7 @@ 
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co Usage: ClaudeUsage{}, Model: originalModel, Stream: false, - Duration: time.Since(time.Now()), + Duration: time.Since(startTime), FirstTokenMs: nil, }, nil default: @@ -1974,6 +1975,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co if mappedModel == "" { return nil, s.writeGoogleError(c, http.StatusForbidden, fmt.Sprintf("model %s not in whitelist", originalModel)) } + billingModel := mappedModel // 获取 access_token if s.tokenProvider == nil { @@ -2002,9 +2004,9 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co // 清理 Schema if cleanedBody, err := cleanGeminiRequest(injectedBody); err == nil { injectedBody = cleanedBody - log.Printf("[Antigravity] Cleaned request schema in forwarded request for account %s", account.Name) + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Cleaned request schema in forwarded request for account %s", account.Name) } else { - log.Printf("[Antigravity] Failed to clean schema: %v", err) + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Failed to clean schema: %v", err) } // 包装请求 @@ -2044,6 +2046,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co ForceCacheBilling: switchErr.IsStickySession, } } + // 区分客户端取消和真正的上游失败,返回更准确的错误消息 + if c.Request.Context().Err() != nil { + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Client disconnected before upstream response") + } return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") } resp := result.resp @@ -2066,7 +2072,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co isModelNotFoundError(resp.StatusCode, respBody) { fallbackModel := s.settingService.GetFallbackModel(ctx, PlatformAntigravity) if fallbackModel != "" && fallbackModel != mappedModel { - log.Printf("[Antigravity] Model not found (%s), 
retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name) + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name) fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, injectedBody) if err == nil { @@ -2149,7 +2155,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co Message: upstreamMsg, Detail: upstreamDetail, }) - log.Printf("[antigravity-Forward] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(unwrappedForOps, 500)) + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(unwrappedForOps, 500)) c.Data(resp.StatusCode, contentType, unwrappedForOps) return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode) } @@ -2168,7 +2174,7 @@ handleSuccess: // 客户端要求流式,直接透传 streamRes, err := s.handleGeminiStreamingResponse(c, resp, startTime) if err != nil { - log.Printf("%s status=stream_error error=%v", prefix, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err) return nil, err } usage = streamRes.usage @@ -2178,7 +2184,7 @@ handleSuccess: // 客户端要求非流式,收集流式响应后返回 streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime) if err != nil { - log.Printf("%s status=stream_collect_error error=%v", prefix, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err) return nil, err } usage = streamRes.usage @@ -2199,7 +2205,7 @@ handleSuccess: return &ForwardResult{ RequestID: requestID, Usage: *usage, - Model: originalModel, + Model: billingModel, Stream: stream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, @@ -2284,7 +2290,7 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool { // 
isSingleAccountRetry 检查 context 中是否设置了单账号退避重试标记 func isSingleAccountRetry(ctx context.Context) bool { - v, _ := ctx.Value(ctxkey.SingleAccountRetry).(bool) + v, _ := SingleAccountRetryFromContext(ctx) return v } @@ -2297,13 +2303,13 @@ func setModelRateLimitByModelName(ctx context.Context, repo AccountRepository, a } // 直接使用官方模型 ID 作为 key,不再转换为 scope if err := repo.SetModelRateLimit(ctx, accountID, modelName, resetAt); err != nil { - log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err) return false } if afterSmartRetry { - log.Printf("%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second)) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second)) } else { - log.Printf("%s status=%d model_rate_limited model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second)) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limited model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second)) } return true } @@ -2411,7 +2417,7 @@ func parseAntigravitySmartRetryInfo(body []byte) *antigravitySmartRetryInfo { // 例如: "0.5s", "10s", "4m50s", "1h30m", "200ms" 等 dur, err := time.ParseDuration(delay) if err != nil { - log.Printf("[Antigravity] failed to parse retryDelay: %s error=%v", delay, err) + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] failed to parse retryDelay: %s error=%v", delay, err) continue } retryDelay = dur @@ -2532,7 +2538,7 @@ func (s 
*AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit // RATE_LIMIT_EXCEEDED: < antigravityRateLimitThreshold: 等待后重试 if info.RetryDelay < antigravityRateLimitThreshold { - log.Printf("%s status=%d model_rate_limit_wait model=%s wait=%v", + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limit_wait model=%s wait=%v", p.prefix, p.statusCode, info.ModelName, info.RetryDelay) return &handleModelRateLimitResult{ Handled: true, @@ -2557,12 +2563,12 @@ func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit // setModelRateLimitAndClearSession 设置模型限流并清除粘性会话 func (s *AntigravityGatewayService) setModelRateLimitAndClearSession(p *handleModelRateLimitParams, info *antigravitySmartRetryInfo) { resetAt := time.Now().Add(info.RetryDelay) - log.Printf("%s status=%d model_rate_limited model=%s account=%d reset_in=%v", + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limited model=%s account=%d reset_in=%v", p.prefix, p.statusCode, info.ModelName, p.account.ID, info.RetryDelay) // 设置模型限流状态(数据库) if err := s.accountRepo.SetModelRateLimit(p.ctx, p.account.ID, info.ModelName, resetAt); err != nil { - log.Printf("%s model_rate_limit_failed model=%s error=%v", p.prefix, info.ModelName, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s model_rate_limit_failed model=%s error=%v", p.prefix, info.ModelName, err) } // 立即更新 Redis 快照中账号的限流状态,避免并发请求重复选中 @@ -2598,7 +2604,7 @@ func (s *AntigravityGatewayService) updateAccountModelRateLimitInCache(ctx conte // 更新 Redis 快照 if err := s.schedulerSnapshot.UpdateAccountInCache(ctx, account); err != nil { - log.Printf("[antigravity-Forward] cache_update_failed account=%d model=%s err=%v", account.ID, modelKey, err) + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] cache_update_failed account=%d model=%s err=%v", account.ID, modelKey, err) } } @@ -2637,20 +2643,29 @@ func (s *AntigravityGatewayService) 
handleUpstreamError( // 429:尝试解析模型级限流,解析失败时兜底为账号级限流 if statusCode == 429 { if logBody, maxBytes := s.getLogConfig(); logBody { - log.Printf("[Antigravity-Debug] 429 response body: %s", truncateString(string(body), maxBytes)) + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity-Debug] 429 response body: %s", truncateString(string(body), maxBytes)) } resetAt := ParseGeminiRateLimitResetTime(body) defaultDur := s.getDefaultRateLimitDuration() // 尝试解析模型 key 并设置模型级限流 - modelKey := resolveAntigravityModelKey(requestedModel) + // + // 注意:requestedModel 可能是"映射前"的请求模型名(例如 claude-opus-4-6), + // 调度与限流判定使用的是 Antigravity 最终模型名(包含映射与 thinking 后缀)。 + // 因此这里必须写入最终模型 key,确保后续调度能正确避开已限流模型。 + modelKey := resolveFinalAntigravityModelKey(ctx, account, requestedModel) + if strings.TrimSpace(modelKey) == "" { + // 极少数情况下无法映射(理论上不应发生:能转发成功说明映射已通过), + // 保持旧行为作为兜底,避免完全丢失模型级限流记录。 + modelKey = resolveAntigravityModelKey(requestedModel) + } if modelKey != "" { ra := s.resolveResetTime(resetAt, defaultDur) if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, ra); err != nil { - log.Printf("%s status=429 model_rate_limit_set_failed model=%s error=%v", prefix, modelKey, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 model_rate_limit_set_failed model=%s error=%v", prefix, modelKey, err) } else { - log.Printf("%s status=429 model_rate_limited model=%s account=%d reset_at=%v reset_in=%v", + logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 model_rate_limited model=%s account=%d reset_at=%v reset_in=%v", prefix, modelKey, account.ID, ra.Format("15:04:05"), time.Until(ra).Truncate(time.Second)) s.updateAccountModelRateLimitInCache(ctx, account, modelKey, ra) } @@ -2659,10 +2674,10 @@ func (s *AntigravityGatewayService) handleUpstreamError( // 无法解析模型 key,兜底为账号级限流 ra := s.resolveResetTime(resetAt, defaultDur) - log.Printf("%s status=429 rate_limited account=%d reset_at=%v reset_in=%v (fallback)", + 
logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 rate_limited account=%d reset_at=%v reset_in=%v (fallback)", prefix, account.ID, ra.Format("15:04:05"), time.Until(ra).Truncate(time.Second)) if err := s.accountRepo.SetRateLimited(ctx, account.ID, ra); err != nil { - log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err) } return nil } @@ -2672,7 +2687,7 @@ func (s *AntigravityGatewayService) handleUpstreamError( } shouldDisable := s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body) if shouldDisable { - log.Printf("%s status=%d marked_error", prefix, statusCode) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d marked_error", prefix, statusCode) } return nil } @@ -2746,18 +2761,18 @@ func (cw *antigravityClientWriter) Disconnected() bool { return cw.disconnected func (cw *antigravityClientWriter) markDisconnected() { cw.disconnected = true - log.Printf("Client disconnected during streaming (%s), continuing to drain upstream for billing", cw.prefix) + logger.LegacyPrintf("service.antigravity_gateway", "Client disconnected during streaming (%s), continuing to drain upstream for billing", cw.prefix) } // handleStreamReadError 处理上游读取错误的通用逻辑。 // 返回 (clientDisconnect, handled):handled=true 表示错误已处理,调用方应返回已收集的 usage。 func handleStreamReadError(err error, clientDisconnected bool, prefix string) (disconnect bool, handled bool) { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - log.Printf("Context canceled during streaming (%s), returning collected usage", prefix) + logger.LegacyPrintf("service.antigravity_gateway", "Context canceled during streaming (%s), returning collected usage", prefix) return true, true } if clientDisconnected { - log.Printf("Upstream read error after client disconnect (%s): %v, 
returning collected usage", prefix, err) + logger.LegacyPrintf("service.antigravity_gateway", "Upstream read error after client disconnect (%s): %v, returning collected usage", prefix, err) return true, true } return false, false @@ -2786,7 +2801,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.settingService.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) usage := &ClaudeUsage{} var firstTokenMs *int @@ -2807,7 +2823,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context } var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -2818,7 +2835,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) // 上游数据间隔超时保护(防止上游挂起长期占用连接) @@ -2860,7 +2877,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil } if errors.Is(ev.err, bufio.ErrTooLong) { - log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err) + logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err) sendErrorEvent("response_too_large") return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err } @@ -2884,19 +2901,19 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context } // 解析 usage + if u := 
extractGeminiUsage(inner); u != nil { + usage = u + } var parsed map[string]any if json.Unmarshal(inner, &parsed) == nil { - if u := extractGeminiUsage(parsed); u != nil { - usage = u - } // Check for MALFORMED_FUNCTION_CALL if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 { if cand, ok := candidates[0].(map[string]any); ok { if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" { - log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward stream") + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] MALFORMED_FUNCTION_CALL detected in forward stream") if content, ok := cand["content"]; ok { if b, err := json.Marshal(content); err == nil { - log.Printf("[Antigravity] Malformed content: %s", string(b)) + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Malformed content: %s", string(b)) } } } @@ -2921,10 +2938,10 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context continue } if cw.Disconnected() { - log.Printf("Upstream timeout after client disconnect (antigravity gemini), returning collected usage") + logger.LegacyPrintf("service.antigravity_gateway", "Upstream timeout after client disconnect (antigravity gemini), returning collected usage") return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil } - log.Printf("Stream data interval timeout (antigravity)") + logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)") sendErrorEvent("stream_timeout") return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") } @@ -2939,7 +2956,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.settingService.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), 
maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) usage := &ClaudeUsage{} var firstTokenMs *int @@ -2967,7 +2985,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -2978,7 +2997,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) // 上游数据间隔超时保护(防止上游挂起长期占用连接) @@ -3005,7 +3024,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont } if ev.err != nil { if errors.Is(ev.err, bufio.ErrTooLong) { - log.Printf("SSE line too long (antigravity non-stream): max_size=%d error=%v", maxLineSize, ev.err) + logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity non-stream): max_size=%d error=%v", maxLineSize, ev.err) } return nil, ev.err } @@ -3042,7 +3061,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont last = parsed // 提取 usage - if u := extractGeminiUsage(parsed); u != nil { + if u := extractGeminiUsage(inner); u != nil { usage = u } @@ -3050,10 +3069,10 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 { if cand, ok := candidates[0].(map[string]any); ok { if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" { - log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward non-stream collect") + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] MALFORMED_FUNCTION_CALL detected in forward non-stream collect") if content, ok := cand["content"]; ok { 
if b, err := json.Marshal(content); err == nil { - log.Printf("[Antigravity] Malformed content: %s", string(b)) + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Malformed content: %s", string(b)) } } } @@ -3080,7 +3099,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont if time.Since(lastRead) < streamInterval { continue } - log.Printf("Stream data interval timeout (antigravity non-stream)") + logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity non-stream)") return nil, fmt.Errorf("stream data interval timeout") } } @@ -3091,7 +3110,7 @@ returnResponse: // 处理空响应情况 — 触发同账号重试 + failover 切换账号 if last == nil && lastWithParts == nil { - log.Printf("[antigravity-Forward] warning: empty stream response (gemini non-stream), triggering failover") + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] warning: empty stream response (gemini non-stream), triggering failover") return nil, &UpstreamFailoverError{ StatusCode: http.StatusBadGateway, ResponseBody: []byte(`{"error":"empty stream response from upstream"}`), @@ -3311,7 +3330,7 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou // 记录上游错误详情便于排障(可选:由配置控制;不回显到客户端) if logBody { - log.Printf("[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes)) + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes)) } // 检查错误透传规则 @@ -3402,7 +3421,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.settingService.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) var firstTokenMs *int var last map[string]any @@ 
-3428,7 +3448,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -3439,7 +3460,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) // 上游数据间隔超时保护(防止上游挂起长期占用连接) @@ -3466,7 +3487,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont } if ev.err != nil { if errors.Is(ev.err, bufio.ErrTooLong) { - log.Printf("SSE line too long (antigravity claude non-stream): max_size=%d error=%v", maxLineSize, ev.err) + logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity claude non-stream): max_size=%d error=%v", maxLineSize, ev.err) } return nil, ev.err } @@ -3515,7 +3536,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont if time.Since(lastRead) < streamInterval { continue } - log.Printf("Stream data interval timeout (antigravity claude non-stream)") + logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity claude non-stream)") return nil, fmt.Errorf("stream data interval timeout") } } @@ -3526,7 +3547,7 @@ returnResponse: // 处理空响应情况 — 触发同账号重试 + failover 切换账号 if last == nil && lastWithParts == nil { - log.Printf("[antigravity-Forward] warning: empty stream response (claude non-stream), triggering failover") + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] warning: empty stream response (claude non-stream), triggering failover") return nil, &UpstreamFailoverError{ StatusCode: http.StatusBadGateway, ResponseBody: []byte(`{"error":"empty stream response from upstream"}`), @@ -3548,7 +3569,7 
@@ returnResponse: // 转换 Gemini 响应为 Claude 格式 claudeResp, agUsage, err := antigravity.TransformGeminiToClaude(geminiBody, originalModel) if err != nil { - log.Printf("[antigravity-Forward] transform_error error=%v body=%s", err, string(geminiBody)) + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] transform_error error=%v body=%s", err, string(geminiBody)) return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") } @@ -3586,7 +3607,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.settingService.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) // 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage convertUsage := func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage { @@ -3618,7 +3640,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context } var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -3629,7 +3652,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) streamInterval := time.Duration(0) @@ -3681,7 +3704,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil } if errors.Is(ev.err, bufio.ErrTooLong) { - log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err) + 
logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err) sendErrorEvent("response_too_large") return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, ev.err } @@ -3705,10 +3728,10 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context continue } if cw.Disconnected() { - log.Printf("Upstream timeout after client disconnect (antigravity claude), returning collected usage") + logger.LegacyPrintf("service.antigravity_gateway", "Upstream timeout after client disconnect (antigravity claude), returning collected usage") return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: true}, nil } - log.Printf("Stream data interval timeout (antigravity)") + logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)") sendErrorEvent("stream_timeout") return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") } @@ -3733,14 +3756,17 @@ func (s *AntigravityGatewayService) extractImageSize(body []byte) string { } // isImageGenerationModel 判断模型是否为图片生成模型 -// 支持的模型:gemini-3-pro-image, gemini-3-pro-image-preview, gemini-2.5-flash-image 等 +// 支持的模型:gemini-3.1-flash-image, gemini-3-pro-image, gemini-2.5-flash-image 等 func isImageGenerationModel(model string) bool { modelLower := strings.ToLower(model) // 移除 models/ 前缀 modelLower = strings.TrimPrefix(modelLower, "models/") // 精确匹配或前缀匹配 - return modelLower == "gemini-3-pro-image" || + return modelLower == "gemini-3.1-flash-image" || + modelLower == "gemini-3.1-flash-image-preview" || + strings.HasPrefix(modelLower, "gemini-3.1-flash-image-") || + modelLower == "gemini-3-pro-image" || modelLower == "gemini-3-pro-image-preview" || strings.HasPrefix(modelLower, "gemini-3-pro-image-") || modelLower == "gemini-2.5-flash-image" || @@ -3875,7 +3901,6 @@ func (s 
*AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. return nil, fmt.Errorf("missing model") } originalModel := claudeReq.Model - billingModel := originalModel // 构建上游请求 URL upstreamURL := baseURL + "/v1/messages" @@ -3908,7 +3933,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. // 发送请求 resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) if err != nil { - log.Printf("%s upstream request failed: %v", prefix, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s upstream request failed: %v", prefix, err) return nil, fmt.Errorf("upstream request failed: %w", err) } defer func() { _ = resp.Body.Close() }() @@ -3928,7 +3953,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. _, _ = c.Writer.Write(respBody) return &ForwardResult{ - Model: billingModel, + Model: originalModel, }, nil } @@ -3966,10 +3991,10 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. 
// 构建计费结果 duration := time.Since(startTime) - log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds()) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=success duration_ms=%d", prefix, duration.Milliseconds()) return &ForwardResult{ - Model: billingModel, + Model: originalModel, Stream: claudeReq.Stream, Duration: duration, FirstTokenMs: firstTokenMs, @@ -4052,7 +4077,7 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity upstream"); handled { return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect} } - log.Printf("Stream read error (antigravity upstream): %v", ev.err) + logger.LegacyPrintf("service.antigravity_gateway", "Stream read error (antigravity upstream): %v", ev.err) return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs} } @@ -4076,10 +4101,10 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp continue } if cw.Disconnected() { - log.Printf("Upstream timeout after client disconnect (antigravity upstream), returning collected usage") + logger.LegacyPrintf("service.antigravity_gateway", "Upstream timeout after client disconnect (antigravity upstream), returning collected usage") return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true} } - log.Printf("Stream data interval timeout (antigravity upstream)") + logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity upstream)") return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs} } } @@ -4111,6 +4136,15 @@ func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUs if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 { usage.CacheCreationInputTokens = int(v) } + // 解析嵌套的 cache_creation 对象中的 5m/1h 明细 + if cc, ok := 
u["cache_creation"].(map[string]any); ok { + if v, ok := cc["ephemeral_5m_input_tokens"].(float64); ok { + usage.CacheCreation5mTokens = int(v) + } + if v, ok := cc["ephemeral_1h_input_tokens"].(float64); ok { + usage.CacheCreation1hTokens = int(v) + } + } } // extractClaudeUsage 从非流式 Claude 响应提取 usage @@ -4133,6 +4167,15 @@ func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage if v, ok := u["cache_creation_input_tokens"].(float64); ok { usage.CacheCreationInputTokens = int(v) } + // 解析嵌套的 cache_creation 对象中的 5m/1h 明细 + if cc, ok := u["cache_creation"].(map[string]any); ok { + if v, ok := cc["ephemeral_5m_input_tokens"].(float64); ok { + usage.CacheCreation5mTokens = int(v) + } + if v, ok := cc["ephemeral_1h_input_tokens"].(float64); ok { + usage.CacheCreation1hTokens = int(v) + } + } } return usage } diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index b312e5ca..84b65adc 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -133,6 +134,36 @@ func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, return s.resp, s.err } +type antigravitySettingRepoStub struct{} + +func (s *antigravitySettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *antigravitySettingRepoStub) GetValue(ctx context.Context, key string) (string, error) { + return "", ErrSettingNotFound +} + +func (s *antigravitySettingRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *antigravitySettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + panic("unexpected GetMultiple call") +} + +func (s *antigravitySettingRepoStub) 
SetMultiple(ctx context.Context, settings map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *antigravitySettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *antigravitySettingRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) { gin.SetMode(gin.TestMode) writer := httptest.NewRecorder() @@ -159,8 +190,9 @@ func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) { } svc := &AntigravityGatewayService{ - tokenProvider: &AntigravityTokenProvider{}, - httpUpstream: &httpUpstreamStub{resp: resp}, + settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}), + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: &httpUpstreamStub{resp: resp}, } account := &Account{ @@ -417,6 +449,151 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling( require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch") } +// TestAntigravityGatewayService_Forward_BillsWithMappedModel +// 验证:Antigravity Claude 转发返回的计费模型使用映射后的模型 +func TestAntigravityGatewayService_Forward_BillsWithMappedModel(t *testing.T) { + gin.SetMode(gin.TestMode) + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + + body, err := json.Marshal(map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []map[string]any{ + {"role": "user", "content": "hello"}, + }, + "max_tokens": 16, + "stream": true, + }) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request = req + + upstreamBody := []byte("data: 
{\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n") + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"X-Request-Id": []string{"req-bill-1"}}, + Body: io.NopCloser(bytes.NewReader(upstreamBody)), + } + + svc := &AntigravityGatewayService{ + settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}), + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: &httpUpstreamStub{resp: resp}, + } + + const mappedModel = "gemini-3-pro-high" + account := &Account{ + ID: 5, + Name: "acc-forward-billing", + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "token", + "model_mapping": map[string]any{ + "claude-sonnet-4-5": mappedModel, + }, + }, + } + + result, err := svc.Forward(context.Background(), c, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, mappedModel, result.Model) +} + +// TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel +// 验证:Antigravity Gemini 转发返回的计费模型使用映射后的模型 +func TestAntigravityGatewayService_ForwardGemini_BillsWithMappedModel(t *testing.T) { + gin.SetMode(gin.TestMode) + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + + body, err := json.Marshal(map[string]any{ + "contents": []map[string]any{ + {"role": "user", "parts": []map[string]any{{"text": "hello"}}}, + }, + }) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body)) + c.Request = req + + upstreamBody := []byte("data: 
{\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}}\n\n") + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"X-Request-Id": []string{"req-bill-2"}}, + Body: io.NopCloser(bytes.NewReader(upstreamBody)), + } + + svc := &AntigravityGatewayService{ + settingService: NewSettingService(&antigravitySettingRepoStub{}, &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}), + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: &httpUpstreamStub{resp: resp}, + } + + const mappedModel = "gemini-3-pro-high" + account := &Account{ + ID: 6, + Name: "acc-gemini-billing", + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "token", + "model_mapping": map[string]any{ + "gemini-2.5-flash": mappedModel, + }, + }, + } + + result, err := svc.ForwardGemini(context.Background(), c, account, "gemini-2.5-flash", "generateContent", true, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, mappedModel, result.Model) +} + +// TestStreamUpstreamResponse_UsageAndFirstToken +// 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间 +func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + defer func() { _ = pw.Close() }() + fmt.Fprintln(pw, `data: {"usage":{"input_tokens":1,"output_tokens":2,"cache_read_input_tokens":3,"cache_creation_input_tokens":4}}`) + fmt.Fprintln(pw, `data: 
{"usage":{"output_tokens":5}}`) + }() + + start := time.Now().Add(-10 * time.Millisecond) + result := svc.streamUpstreamResponse(c, resp, start) + _ = pr.Close() + + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 1, result.usage.InputTokens) + // 第二次事件覆盖 output_tokens + require.Equal(t, 5, result.usage.OutputTokens) + require.Equal(t, 3, result.usage.CacheReadInputTokens) + require.Equal(t, 4, result.usage.CacheCreationInputTokens) + require.NotNil(t, result.firstTokenMs) + + // 确保有透传输出 + require.Contains(t, rec.Body.String(), "data:") +} + // --- 流式 happy path 测试 --- // TestStreamUpstreamResponse_NormalComplete @@ -920,3 +1097,144 @@ func TestAntigravityClientWriter(t *testing.T) { require.True(t, cw.Disconnected()) }) } + +// TestUnwrapV1InternalResponse 测试 unwrapV1InternalResponse 的各种输入场景 +func TestUnwrapV1InternalResponse(t *testing.T) { + svc := &AntigravityGatewayService{} + + // 构造 >50KB 的大型 JSON + largePadding := strings.Repeat("x", 50*1024) + largeInput := []byte(fmt.Sprintf(`{"response":{"id":"big","pad":"%s"}}`, largePadding)) + largeExpected := fmt.Sprintf(`{"id":"big","pad":"%s"}`, largePadding) + + tests := []struct { + name string + input []byte + expected string + wantErr bool + }{ + { + name: "正常 response 包装", + input: []byte(`{"response":{"id":"123","content":"hello"}}`), + expected: `{"id":"123","content":"hello"}`, + }, + { + name: "无 response 透传", + input: []byte(`{"id":"456"}`), + expected: `{"id":"456"}`, + }, + { + name: "空 JSON", + input: []byte(`{}`), + expected: `{}`, + }, + { + name: "response 为 null", + input: []byte(`{"response":null}`), + expected: `null`, + }, + { + name: "response 为基础类型 string", + input: []byte(`{"response":"hello"}`), + expected: `"hello"`, + }, + { + name: "非法 JSON", + input: []byte(`not json`), + expected: `not json`, + }, + { + name: "嵌套 response 只解一层", + input: []byte(`{"response":{"response":{"inner":true}}}`), + expected: `{"response":{"inner":true}}`, + }, + { + name: "大型 JSON 
>50KB", + input: largeInput, + expected: largeExpected, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := svc.unwrapV1InternalResponse(tt.input) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.expected, strings.TrimSpace(string(got))) + }) + } +} + +// --- unwrapV1InternalResponse benchmark 对照组 --- + +// unwrapV1InternalResponseOld 旧实现:Unmarshal+Marshal 双重开销(仅用于 benchmark 对照) +func unwrapV1InternalResponseOld(body []byte) ([]byte, error) { + var outer map[string]any + if err := json.Unmarshal(body, &outer); err != nil { + return nil, err + } + if resp, ok := outer["response"]; ok { + return json.Marshal(resp) + } + return body, nil +} + +func BenchmarkUnwrapV1Internal_Old_Small(b *testing.B) { + body := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"hello world"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}}`) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = unwrapV1InternalResponseOld(body) + } +} + +func BenchmarkUnwrapV1Internal_New_Small(b *testing.B) { + body := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"hello world"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}}`) + svc := &AntigravityGatewayService{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = svc.unwrapV1InternalResponse(body) + } +} + +func BenchmarkUnwrapV1Internal_Old_Large(b *testing.B) { + body := generateLargeUnwrapJSON(10 * 1024) // ~10KB + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = unwrapV1InternalResponseOld(body) + } +} + +func BenchmarkUnwrapV1Internal_New_Large(b *testing.B) { + body := generateLargeUnwrapJSON(10 * 1024) // ~10KB + svc := &AntigravityGatewayService{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = svc.unwrapV1InternalResponse(body) + } +} + +// generateLargeUnwrapJSON 生成指定最小大小的包含 response 包装的 JSON +func generateLargeUnwrapJSON(minSize int) []byte { + parts := 
make([]map[string]string, 0) + current := 0 + for current < minSize { + text := fmt.Sprintf("这是第 %d 段内容,用于填充 JSON 到目标大小。", len(parts)+1) + parts = append(parts, map[string]string{"text": text}) + current += len(text) + 20 // 估算 JSON 编码开销 + } + inner := map[string]any{ + "candidates": []map[string]any{ + {"content": map[string]any{"parts": parts}}, + }, + "usageMetadata": map[string]any{ + "promptTokenCount": 100, + "candidatesTokenCount": 50, + }, + } + outer := map[string]any{"response": inner} + b, _ := json.Marshal(outer) + return b +} diff --git a/backend/internal/service/antigravity_model_mapping_test.go b/backend/internal/service/antigravity_model_mapping_test.go index f3621555..71939d26 100644 --- a/backend/internal/service/antigravity_model_mapping_test.go +++ b/backend/internal/service/antigravity_model_mapping_test.go @@ -76,6 +76,12 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) { }, // 3. 默认映射中的透传(映射到自己) + { + name: "默认映射透传 - claude-sonnet-4-6", + requestedModel: "claude-sonnet-4-6", + accountMapping: nil, + expected: "claude-sonnet-4-6", + }, { name: "默认映射透传 - claude-sonnet-4-5", requestedModel: "claude-sonnet-4-5", diff --git a/backend/internal/service/antigravity_oauth_service.go b/backend/internal/service/antigravity_oauth_service.go index b67c7faf..5f6691be 100644 --- a/backend/internal/service/antigravity_oauth_service.go +++ b/backend/internal/service/antigravity_oauth_service.go @@ -112,7 +112,10 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig } } - client := antigravity.NewClient(proxyURL) + client, err := antigravity.NewClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create antigravity client failed: %w", err) + } // 交换 token tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier) @@ -167,7 +170,10 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken time.Sleep(backoff) } - client := antigravity.NewClient(proxyURL) + client, 
err := antigravity.NewClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create antigravity client failed: %w", err) + } tokenResp, err := client.RefreshToken(ctx, refreshToken) if err == nil { now := time.Now() @@ -209,7 +215,10 @@ func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refr } // 获取用户信息(email) - client := antigravity.NewClient(proxyURL) + client, err := antigravity.NewClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create antigravity client failed: %w", err) + } userInfo, err := client.GetUserInfo(ctx, tokenInfo.AccessToken) if err != nil { fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", err) @@ -309,7 +318,10 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac time.Sleep(backoff) } - client := antigravity.NewClient(proxyURL) + client, err := antigravity.NewClient(proxyURL) + if err != nil { + return "", fmt.Errorf("create antigravity client failed: %w", err) + } loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken) if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" { diff --git a/backend/internal/service/antigravity_quota_fetcher.go b/backend/internal/service/antigravity_quota_fetcher.go index 07eb563d..e950ec1d 100644 --- a/backend/internal/service/antigravity_quota_fetcher.go +++ b/backend/internal/service/antigravity_quota_fetcher.go @@ -2,6 +2,7 @@ package service import ( "context" + "fmt" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" @@ -31,7 +32,10 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou accessToken := account.GetCredential("access_token") projectID := account.GetCredential("project_id") - client := antigravity.NewClient(proxyURL) + client, err := antigravity.NewClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("create antigravity client failed: %w", err) + } // 调用 API 获取配额 modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, 
projectID) diff --git a/backend/internal/service/antigravity_rate_limit_test.go b/backend/internal/service/antigravity_rate_limit_test.go index 0befa7d9..dd8dd83f 100644 --- a/backend/internal/service/antigravity_rate_limit_test.go +++ b/backend/internal/service/antigravity_rate_limit_test.go @@ -15,6 +15,12 @@ import ( "github.com/stretchr/testify/require" ) +// 编译期接口断言 +var _ HTTPUpstream = (*stubAntigravityUpstream)(nil) +var _ HTTPUpstream = (*recordingOKUpstream)(nil) +var _ AccountRepository = (*stubAntigravityAccountRepo)(nil) +var _ SchedulerCache = (*stubSchedulerCache)(nil) + type stubAntigravityUpstream struct { firstBase string secondBase string @@ -191,6 +197,22 @@ func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) { require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey) } +// TestHandleUpstreamError_429_NonModelRateLimit_UsesMappedModelKey 测试 429 非模型限流场景 +// 验证:requestedModel 会被映射到 Antigravity 最终模型(例如 claude-opus-4-6 -> claude-opus-4-6-thinking) +func TestHandleUpstreamError_429_NonModelRateLimit_UsesMappedModelKey(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 20, Name: "acc-20", Platform: PlatformAntigravity} + + body := buildGeminiRateLimitBody("5s") + + result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-opus-4-6", 0, "", false) + + require.Nil(t, result) + require.Len(t, repo.modelRateLimitCalls, 1) + require.Equal(t, "claude-opus-4-6-thinking", repo.modelRateLimitCalls[0].modelKey) +} + // TestHandleUpstreamError_503_ModelCapacityExhausted 测试 503 模型容量不足场景 // MODEL_CAPACITY_EXHAUSTED 时应等待重试,不切换账号 func TestHandleUpstreamError_503_ModelCapacityExhausted(t *testing.T) { diff --git a/backend/internal/service/api_key.go b/backend/internal/service/api_key.go index d66059dd..4c565495 100644 --- a/backend/internal/service/api_key.go +++ 
b/backend/internal/service/api_key.go @@ -1,6 +1,10 @@ package service -import "time" +import ( + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" +) // API Key status constants const ( @@ -19,21 +23,41 @@ type APIKey struct { Status string IPWhitelist []string IPBlacklist []string - CreatedAt time.Time - UpdatedAt time.Time - User *User - Group *Group + // 预编译的 IP 规则,用于认证热路径避免重复 ParseIP/ParseCIDR。 + CompiledIPWhitelist *ip.CompiledIPRules `json:"-"` + CompiledIPBlacklist *ip.CompiledIPRules `json:"-"` + LastUsedAt *time.Time + CreatedAt time.Time + UpdatedAt time.Time + User *User + Group *Group // Quota fields Quota float64 // Quota limit in USD (0 = unlimited) QuotaUsed float64 // Used quota amount ExpiresAt *time.Time // Expiration time (nil = never expires) + + // Rate limit fields + RateLimit5h float64 // Rate limit in USD per 5h (0 = unlimited) + RateLimit1d float64 // Rate limit in USD per 1d (0 = unlimited) + RateLimit7d float64 // Rate limit in USD per 7d (0 = unlimited) + Usage5h float64 // Used amount in current 5h window + Usage1d float64 // Used amount in current 1d window + Usage7d float64 // Used amount in current 7d window + Window5hStart *time.Time // Start of current 5h window + Window1dStart *time.Time // Start of current 1d window + Window7dStart *time.Time // Start of current 7d window } func (k *APIKey) IsActive() bool { return k.Status == StatusActive } +// HasRateLimits returns true if any rate limit window is configured +func (k *APIKey) HasRateLimits() bool { + return k.RateLimit5h > 0 || k.RateLimit1d > 0 || k.RateLimit7d > 0 +} + // IsExpired checks if the API key has expired func (k *APIKey) IsExpired() bool { if k.ExpiresAt == nil { @@ -73,3 +97,10 @@ func (k *APIKey) GetDaysUntilExpiry() int { } return int(duration.Hours() / 24) } + +// APIKeyListFilters holds optional filtering parameters for listing API keys. 
+type APIKeyListFilters struct { + Search string + Status string + GroupID *int64 // nil=不筛选, 0=无分组, >0=指定分组 +} diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index d15b5817..83933f42 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -19,6 +19,11 @@ type APIKeyAuthSnapshot struct { // Expiration field for API Key expiration feature ExpiresAt *time.Time `json:"expires_at,omitempty"` // Expiration time (nil = never expires) + + // Rate limit configuration (only limits, not usage - usage read from Redis at check time) + RateLimit5h float64 `json:"rate_limit_5h"` + RateLimit1d float64 `json:"rate_limit_1d"` + RateLimit7d float64 `json:"rate_limit_7d"` } // APIKeyAuthUserSnapshot 用户快照 @@ -44,6 +49,10 @@ type APIKeyAuthGroupSnapshot struct { ImagePrice1K *float64 `json:"image_price_1k,omitempty"` ImagePrice2K *float64 `json:"image_price_2k,omitempty"` ImagePrice4K *float64 `json:"image_price_4k,omitempty"` + SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"` + SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"` + SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"` ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"` diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index f5bba7d0..0ca694af 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -6,8 +6,7 @@ import ( "encoding/hex" "errors" "fmt" - "math/rand" - "sync" + "math/rand/v2" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -23,12 +22,6 @@ type 
apiKeyAuthCacheConfig struct { singleflight bool } -var ( - jitterRandMu sync.Mutex - // 认证缓存抖动使用独立随机源,避免全局 Seed - jitterRand = rand.New(rand.NewSource(time.Now().UnixNano())) -) - func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig { if cfg == nil { return apiKeyAuthCacheConfig{} @@ -56,6 +49,8 @@ func (c apiKeyAuthCacheConfig) negativeEnabled() bool { return c.negativeTTL > 0 } +// jitterTTL 为缓存 TTL 添加抖动,避免多个请求在同一时刻同时过期触发集中回源。 +// 这里直接使用 rand/v2 的顶层函数:并发安全,无需全局互斥锁。 func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration { if ttl <= 0 { return ttl @@ -68,9 +63,7 @@ func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration { percent = 100 } delta := float64(percent) / 100 - jitterRandMu.Lock() - randVal := jitterRand.Float64() - jitterRandMu.Unlock() + randVal := rand.Float64() factor := 1 - delta + randVal*(2*delta) if factor <= 0 { return ttl @@ -216,6 +209,9 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { Quota: apiKey.Quota, QuotaUsed: apiKey.QuotaUsed, ExpiresAt: apiKey.ExpiresAt, + RateLimit5h: apiKey.RateLimit5h, + RateLimit1d: apiKey.RateLimit1d, + RateLimit7d: apiKey.RateLimit7d, User: APIKeyAuthUserSnapshot{ ID: apiKey.User.ID, Status: apiKey.User.Status, @@ -238,6 +234,10 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { ImagePrice1K: apiKey.Group.ImagePrice1K, ImagePrice2K: apiKey.Group.ImagePrice2K, ImagePrice4K: apiKey.Group.ImagePrice4K, + SoraImagePrice360: apiKey.Group.SoraImagePrice360, + SoraImagePrice540: apiKey.Group.SoraImagePrice540, + SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, FallbackGroupID: apiKey.Group.FallbackGroupID, FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest, @@ -265,6 +265,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot 
*APIKeyAuthSnapsho Quota: snapshot.Quota, QuotaUsed: snapshot.QuotaUsed, ExpiresAt: snapshot.ExpiresAt, + RateLimit5h: snapshot.RateLimit5h, + RateLimit1d: snapshot.RateLimit1d, + RateLimit7d: snapshot.RateLimit7d, User: &User{ ID: snapshot.User.ID, Status: snapshot.User.Status, @@ -288,6 +291,10 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho ImagePrice1K: snapshot.Group.ImagePrice1K, ImagePrice2K: snapshot.Group.ImagePrice2K, ImagePrice4K: snapshot.Group.ImagePrice4K, + SoraImagePrice360: snapshot.Group.SoraImagePrice360, + SoraImagePrice540: snapshot.Group.SoraImagePrice540, + SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD, ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, FallbackGroupID: snapshot.Group.FallbackGroupID, FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest, @@ -297,5 +304,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho SupportedModelScopes: snapshot.Group.SupportedModelScopes, } } + s.compileAPIKeyIPRules(apiKey) return apiKey } diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index cb1dd60a..b32a1d67 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -5,6 +5,8 @@ import ( "crypto/rand" "encoding/hex" "fmt" + "strconv" + "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -28,10 +30,18 @@ var ( ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key 已过期") // ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key quota exhausted") ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key 额度已用完") + + // Rate limit errors + ErrAPIKeyRateLimit5hExceeded = infraerrors.TooManyRequests("API_KEY_RATE_5H_EXCEEDED", "api key 5小时限额已用完") + ErrAPIKeyRateLimit1dExceeded = 
infraerrors.TooManyRequests("API_KEY_RATE_1D_EXCEEDED", "api key 日限额已用完") + ErrAPIKeyRateLimit7dExceeded = infraerrors.TooManyRequests("API_KEY_RATE_7D_EXCEEDED", "api key 7天限额已用完") ) const ( apiKeyMaxErrorsPerHour = 20 + apiKeyLastUsedMinTouch = 30 * time.Second + // DB 写失败后的短退避,避免请求路径持续同步重试造成写风暴与高延迟。 + apiKeyLastUsedFailBackoff = 5 * time.Second ) type APIKeyRepository interface { @@ -45,7 +55,7 @@ type APIKeyRepository interface { Update(ctx context.Context, key *APIKey) error Delete(ctx context.Context, id int64) error - ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) + ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) CountByUserID(ctx context.Context, userID int64) (int64, error) ExistsByKey(ctx context.Context, key string) (bool, error) @@ -58,6 +68,22 @@ type APIKeyRepository interface { // Quota methods IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) + UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error + + // Rate limit methods + IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error + ResetRateLimitWindows(ctx context.Context, id int64) error + GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) +} + +// APIKeyRateLimitData holds rate limit usage and window state for an API key. 
+type APIKeyRateLimitData struct { + Usage5h float64 + Usage1d float64 + Usage7d float64 + Window5hStart *time.Time + Window1dStart *time.Time + Window7dStart *time.Time } // APIKeyCache defines cache operations for API key service @@ -96,6 +122,11 @@ type CreateAPIKeyRequest struct { // Quota fields Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited) ExpiresInDays *int `json:"expires_in_days"` // Days until expiry (nil = never expires) + + // Rate limit fields (0 = unlimited) + RateLimit5h float64 `json:"rate_limit_5h"` + RateLimit1d float64 `json:"rate_limit_1d"` + RateLimit7d float64 `json:"rate_limit_7d"` } // UpdateAPIKeyRequest 更新API Key请求 @@ -111,20 +142,34 @@ type UpdateAPIKeyRequest struct { ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = no change) ClearExpiration bool `json:"-"` // Clear expiration (internal use) ResetQuota *bool `json:"reset_quota"` // Reset quota_used to 0 + + // Rate limit fields (nil = no change, 0 = unlimited) + RateLimit5h *float64 `json:"rate_limit_5h"` + RateLimit1d *float64 `json:"rate_limit_1d"` + RateLimit7d *float64 `json:"rate_limit_7d"` + ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // Reset all usage counters to 0 } // APIKeyService API Key服务 +// RateLimitCacheInvalidator invalidates rate limit cache entries on manual reset. 
+type RateLimitCacheInvalidator interface { + InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error +} + type APIKeyService struct { - apiKeyRepo APIKeyRepository - userRepo UserRepository - groupRepo GroupRepository - userSubRepo UserSubscriptionRepository - userGroupRateRepo UserGroupRateRepository - cache APIKeyCache - cfg *config.Config - authCacheL1 *ristretto.Cache - authCfg apiKeyAuthCacheConfig - authGroup singleflight.Group + apiKeyRepo APIKeyRepository + userRepo UserRepository + groupRepo GroupRepository + userSubRepo UserSubscriptionRepository + userGroupRateRepo UserGroupRateRepository + cache APIKeyCache + rateLimitCacheInvalid RateLimitCacheInvalidator // optional: invalidate Redis rate limit cache + cfg *config.Config + authCacheL1 *ristretto.Cache + authCfg apiKeyAuthCacheConfig + authGroup singleflight.Group + lastUsedTouchL1 sync.Map // keyID -> nextAllowedAt(time.Time) + lastUsedTouchSF singleflight.Group } // NewAPIKeyService 创建API Key服务实例 @@ -150,6 +195,20 @@ func NewAPIKeyService( return svc } +// SetRateLimitCacheInvalidator sets the optional rate limit cache invalidator. +// Called after construction (e.g. in wire) to avoid circular dependencies. 
+func (s *APIKeyService) SetRateLimitCacheInvalidator(inv RateLimitCacheInvalidator) { + s.rateLimitCacheInvalid = inv +} + +func (s *APIKeyService) compileAPIKeyIPRules(apiKey *APIKey) { + if apiKey == nil { + return + } + apiKey.CompiledIPWhitelist = ip.CompileIPRules(apiKey.IPWhitelist) + apiKey.CompiledIPBlacklist = ip.CompileIPRules(apiKey.IPBlacklist) +} + // GenerateKey 生成随机API Key func (s *APIKeyService) GenerateKey() (string, error) { // 生成32字节随机数据 @@ -311,6 +370,9 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK IPBlacklist: req.IPBlacklist, Quota: req.Quota, QuotaUsed: 0, + RateLimit5h: req.RateLimit5h, + RateLimit1d: req.RateLimit1d, + RateLimit7d: req.RateLimit7d, } // Set expiration time if specified @@ -324,13 +386,14 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK } s.InvalidateAuthCacheByKey(ctx, apiKey.Key) + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } // List 获取用户的API Key列表 -func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { - keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) +func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) { + keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, filters) if err != nil { return nil, nil, fmt.Errorf("list api keys: %w", err) } @@ -355,6 +418,7 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) if err != nil { return nil, fmt.Errorf("get api key: %w", err) } + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } @@ -367,6 +431,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro if err != nil { return nil, fmt.Errorf("get api key: %w", err) } + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } } @@ -383,6 
+448,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro if err != nil { return nil, fmt.Errorf("get api key: %w", err) } + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } } else { @@ -394,6 +460,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro if err != nil { return nil, fmt.Errorf("get api key: %w", err) } + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } } @@ -403,6 +470,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro return nil, fmt.Errorf("get api key: %w", err) } apiKey.Key = key + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } @@ -497,11 +565,37 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req apiKey.IPWhitelist = req.IPWhitelist apiKey.IPBlacklist = req.IPBlacklist + // Update rate limit configuration + if req.RateLimit5h != nil { + apiKey.RateLimit5h = *req.RateLimit5h + } + if req.RateLimit1d != nil { + apiKey.RateLimit1d = *req.RateLimit1d + } + if req.RateLimit7d != nil { + apiKey.RateLimit7d = *req.RateLimit7d + } + resetRateLimit := req.ResetRateLimitUsage != nil && *req.ResetRateLimitUsage + if resetRateLimit { + apiKey.Usage5h = 0 + apiKey.Usage1d = 0 + apiKey.Usage7d = 0 + apiKey.Window5hStart = nil + apiKey.Window1dStart = nil + apiKey.Window7dStart = nil + } + if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil { return nil, fmt.Errorf("update api key: %w", err) } s.InvalidateAuthCacheByKey(ctx, apiKey.Key) + s.compileAPIKeyIPRules(apiKey) + + // Invalidate Redis rate limit cache so reset takes effect immediately + if resetRateLimit && s.rateLimitCacheInvalid != nil { + _ = s.rateLimitCacheInvalid.InvalidateAPIKeyRateLimit(ctx, apiKey.ID) + } return apiKey, nil } @@ -527,6 +621,7 @@ func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) erro if err := s.apiKeyRepo.Delete(ctx, id); err != nil { return fmt.Errorf("delete api key: %w", err) } + s.lastUsedTouchL1.Delete(id) 
return nil } @@ -558,6 +653,38 @@ func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, * return apiKey, user, nil } +// TouchLastUsed 通过防抖更新 api_keys.last_used_at,减少高频写放大。 +// 该操作为尽力而为,不应阻塞主请求链路。 +func (s *APIKeyService) TouchLastUsed(ctx context.Context, keyID int64) error { + if keyID <= 0 { + return nil + } + + now := time.Now() + if v, ok := s.lastUsedTouchL1.Load(keyID); ok { + if nextAllowedAt, ok := v.(time.Time); ok && now.Before(nextAllowedAt) { + return nil + } + } + + _, err, _ := s.lastUsedTouchSF.Do(strconv.FormatInt(keyID, 10), func() (any, error) { + latest := time.Now() + if v, ok := s.lastUsedTouchL1.Load(keyID); ok { + if nextAllowedAt, ok := v.(time.Time); ok && latest.Before(nextAllowedAt) { + return nil, nil + } + } + + if err := s.apiKeyRepo.UpdateLastUsed(ctx, keyID, latest); err != nil { + s.lastUsedTouchL1.Store(keyID, latest.Add(apiKeyLastUsedFailBackoff)) + return nil, fmt.Errorf("touch api key last used: %w", err) + } + s.lastUsedTouchL1.Store(keyID, latest.Add(apiKeyLastUsedMinTouch)) + return nil, nil + }) + return err +} + // IncrementUsage 增加API Key使用次数(可选:用于统计) func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error { // 使用Redis计数器 @@ -690,3 +817,16 @@ func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cos return nil } + +// GetRateLimitData returns rate limit usage and window state for an API key. +func (s *APIKeyService) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) { + return s.apiKeyRepo.GetRateLimitData(ctx, id) +} + +// UpdateRateLimitUsage atomically increments rate limit usage counters in the DB. 
+func (s *APIKeyService) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error { + if cost <= 0 { + return nil + } + return s.apiKeyRepo.IncrementRateLimitUsage(ctx, apiKeyID, cost) +} diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go index 14ecbf39..97b8e229 100644 --- a/backend/internal/service/api_key_service_cache_test.go +++ b/backend/internal/service/api_key_service_cache_test.go @@ -53,7 +53,7 @@ func (s *authRepoStub) Delete(ctx context.Context, id int64) error { panic("unexpected Delete call") } -func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { +func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) { panic("unexpected ListByUserID call") } @@ -103,6 +103,19 @@ func (s *authRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount panic("unexpected IncrementQuotaUsed call") } +func (s *authRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { + panic("unexpected UpdateLastUsed call") +} +func (s *authRepoStub) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error { + panic("unexpected IncrementRateLimitUsage call") +} +func (s *authRepoStub) ResetRateLimitWindows(ctx context.Context, id int64) error { + panic("unexpected ResetRateLimitWindows call") +} +func (s *authRepoStub) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) { + panic("unexpected GetRateLimitData call") +} + type authCacheStub struct { getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) setAuthKeys []string diff --git a/backend/internal/service/api_key_service_delete_test.go b/backend/internal/service/api_key_service_delete_test.go index 
d4d12144..dfd481e8 100644 --- a/backend/internal/service/api_key_service_delete_test.go +++ b/backend/internal/service/api_key_service_delete_test.go @@ -24,10 +24,13 @@ import ( // - deleteErr: 模拟 Delete 返回的错误 // - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证 type apiKeyRepoStub struct { - apiKey *APIKey // GetKeyAndOwnerID 的返回值 - getByIDErr error // GetKeyAndOwnerID 的错误返回值 - deleteErr error // Delete 的错误返回值 - deletedIDs []int64 // 记录已删除的 API Key ID 列表 + apiKey *APIKey // GetKeyAndOwnerID 的返回值 + getByIDErr error // GetKeyAndOwnerID 的错误返回值 + deleteErr error // Delete 的错误返回值 + deletedIDs []int64 // 记录已删除的 API Key ID 列表 + updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error + touchedIDs []int64 + touchedUsedAts []time.Time } // 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题 @@ -78,7 +81,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error { // 以下是接口要求实现但本测试不关心的方法 -func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { +func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) { panic("unexpected ListByUserID call") } @@ -122,6 +125,27 @@ func (s *apiKeyRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amoun panic("unexpected IncrementQuotaUsed call") } +func (s *apiKeyRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { + s.touchedIDs = append(s.touchedIDs, id) + s.touchedUsedAts = append(s.touchedUsedAts, usedAt) + if s.updateLastUsed != nil { + return s.updateLastUsed(ctx, id, usedAt) + } + return nil +} + +func (s *apiKeyRepoStub) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error { + panic("unexpected IncrementRateLimitUsage call") +} + +func (s *apiKeyRepoStub) ResetRateLimitWindows(ctx context.Context, id int64) error { + panic("unexpected 
ResetRateLimitWindows call") +} + +func (s *apiKeyRepoStub) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) { + panic("unexpected GetRateLimitData call") +} + // apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。 // 用于验证删除操作时缓存清理逻辑是否被正确调用。 // @@ -214,12 +238,15 @@ func TestApiKeyService_Delete_Success(t *testing.T) { } cache := &apiKeyCacheStub{} svc := &APIKeyService{apiKeyRepo: repo, cache: cache} + svc.lastUsedTouchL1.Store(int64(42), time.Now()) err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7 require.NoError(t, err) require.Equal(t, []int64{42}, repo.deletedIDs) // 验证正确的 API Key 被删除 require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除 require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys) + _, exists := svc.lastUsedTouchL1.Load(int64(42)) + require.False(t, exists, "delete should clear touch debounce cache") } // TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。 diff --git a/backend/internal/service/api_key_service_touch_last_used_test.go b/backend/internal/service/api_key_service_touch_last_used_test.go new file mode 100644 index 00000000..b49bf9ce --- /dev/null +++ b/backend/internal/service/api_key_service_touch_last_used_test.go @@ -0,0 +1,160 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestAPIKeyService_TouchLastUsed_InvalidKeyID(t *testing.T) { + repo := &apiKeyRepoStub{ + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + return errors.New("should not be called") + }, + } + svc := &APIKeyService{apiKeyRepo: repo} + + require.NoError(t, svc.TouchLastUsed(context.Background(), 0)) + require.NoError(t, svc.TouchLastUsed(context.Background(), -1)) + require.Empty(t, repo.touchedIDs) +} + +func TestAPIKeyService_TouchLastUsed_FirstTouchSucceeds(t *testing.T) { + repo := &apiKeyRepoStub{} + svc := &APIKeyService{apiKeyRepo: repo} + + 
err := svc.TouchLastUsed(context.Background(), 123) + require.NoError(t, err) + require.Equal(t, []int64{123}, repo.touchedIDs) + require.Len(t, repo.touchedUsedAts, 1) + require.False(t, repo.touchedUsedAts[0].IsZero()) + + cached, ok := svc.lastUsedTouchL1.Load(int64(123)) + require.True(t, ok, "successful touch should update debounce cache") + _, isTime := cached.(time.Time) + require.True(t, isTime) +} + +func TestAPIKeyService_TouchLastUsed_DebouncedWithinWindow(t *testing.T) { + repo := &apiKeyRepoStub{} + svc := &APIKeyService{apiKeyRepo: repo} + + require.NoError(t, svc.TouchLastUsed(context.Background(), 123)) + require.NoError(t, svc.TouchLastUsed(context.Background(), 123)) + + require.Equal(t, []int64{123}, repo.touchedIDs, "second touch within debounce window should not hit repository") +} + +func TestAPIKeyService_TouchLastUsed_ExpiredDebounceTouchesAgain(t *testing.T) { + repo := &apiKeyRepoStub{} + svc := &APIKeyService{apiKeyRepo: repo} + + require.NoError(t, svc.TouchLastUsed(context.Background(), 123)) + + // 强制将 debounce 时间回拨到窗口之外,触发第二次写库。 + svc.lastUsedTouchL1.Store(int64(123), time.Now().Add(-apiKeyLastUsedMinTouch-time.Second)) + + require.NoError(t, svc.TouchLastUsed(context.Background(), 123)) + require.Len(t, repo.touchedIDs, 2) + require.Equal(t, int64(123), repo.touchedIDs[0]) + require.Equal(t, int64(123), repo.touchedIDs[1]) +} + +func TestAPIKeyService_TouchLastUsed_RepoError(t *testing.T) { + repo := &apiKeyRepoStub{ + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + return errors.New("db write failed") + }, + } + svc := &APIKeyService{apiKeyRepo: repo} + + err := svc.TouchLastUsed(context.Background(), 123) + require.Error(t, err) + require.ErrorContains(t, err, "touch api key last used") + require.Equal(t, []int64{123}, repo.touchedIDs) + + cached, ok := svc.lastUsedTouchL1.Load(int64(123)) + require.True(t, ok, "failed touch should still update retry debounce cache") + _, isTime := cached.(time.Time) 
+ require.True(t, isTime) +} + +func TestAPIKeyService_TouchLastUsed_RepoErrorDebounced(t *testing.T) { + repo := &apiKeyRepoStub{ + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + return errors.New("db write failed") + }, + } + svc := &APIKeyService{apiKeyRepo: repo} + + firstErr := svc.TouchLastUsed(context.Background(), 456) + require.Error(t, firstErr) + require.ErrorContains(t, firstErr, "touch api key last used") + + secondErr := svc.TouchLastUsed(context.Background(), 456) + require.NoError(t, secondErr, "failed touch should be debounced and skip immediate retry") + require.Equal(t, []int64{456}, repo.touchedIDs, "debounced retry should not hit repository again") +} + +type touchSingleflightRepo struct { + *apiKeyRepoStub + mu sync.Mutex + calls int + blockCh chan struct{} +} + +func (r *touchSingleflightRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { + r.mu.Lock() + r.calls++ + r.mu.Unlock() + <-r.blockCh + return nil +} + +func TestAPIKeyService_TouchLastUsed_ConcurrentFirstTouchDeduplicated(t *testing.T) { + repo := &touchSingleflightRepo{ + apiKeyRepoStub: &apiKeyRepoStub{}, + blockCh: make(chan struct{}), + } + svc := &APIKeyService{apiKeyRepo: repo} + + const workers = 20 + startCh := make(chan struct{}) + errCh := make(chan error, workers) + var wg sync.WaitGroup + + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-startCh + errCh <- svc.TouchLastUsed(context.Background(), 321) + }() + } + + close(startCh) + + require.Eventually(t, func() bool { + repo.mu.Lock() + defer repo.mu.Unlock() + return repo.calls >= 1 + }, time.Second, 10*time.Millisecond) + + close(repo.blockCh) + wg.Wait() + close(errCh) + + for err := range errCh { + require.NoError(t, err) + } + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Equal(t, 1, repo.calls, "并发首次 touch 只应写库一次") +} diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 
fb8aaf9c..6a17c83f 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -7,13 +7,14 @@ import ( "encoding/hex" "errors" "fmt" - "log" "net/mail" + "strconv" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/golang-jwt/jwt/v5" "golang.org/x/crypto/bcrypt" @@ -33,6 +34,7 @@ var ( ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired") ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused") ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required") + ErrEmailSuffixNotAllowed = infraerrors.BadRequest("EMAIL_SUFFIX_NOT_ALLOWED", "email suffix is not allowed") ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable") ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required") @@ -56,15 +58,20 @@ type JWTClaims struct { // AuthService 认证服务 type AuthService struct { - userRepo UserRepository - redeemRepo RedeemCodeRepository - refreshTokenCache RefreshTokenCache - cfg *config.Config - settingService *SettingService - emailService *EmailService - turnstileService *TurnstileService - emailQueueService *EmailQueueService - promoService *PromoService + userRepo UserRepository + redeemRepo RedeemCodeRepository + refreshTokenCache RefreshTokenCache + cfg *config.Config + settingService *SettingService + emailService *EmailService + turnstileService *TurnstileService + emailQueueService *EmailQueueService + promoService *PromoService + defaultSubAssigner DefaultSubscriptionAssigner +} + +type DefaultSubscriptionAssigner interface { + 
AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) } // NewAuthService 创建认证服务实例 @@ -78,17 +85,19 @@ func NewAuthService( turnstileService *TurnstileService, emailQueueService *EmailQueueService, promoService *PromoService, + defaultSubAssigner DefaultSubscriptionAssigner, ) *AuthService { return &AuthService{ - userRepo: userRepo, - redeemRepo: redeemRepo, - refreshTokenCache: refreshTokenCache, - cfg: cfg, - settingService: settingService, - emailService: emailService, - turnstileService: turnstileService, - emailQueueService: emailQueueService, - promoService: promoService, + userRepo: userRepo, + redeemRepo: redeemRepo, + refreshTokenCache: refreshTokenCache, + cfg: cfg, + settingService: settingService, + emailService: emailService, + turnstileService: turnstileService, + emailQueueService: emailQueueService, + promoService: promoService, + defaultSubAssigner: defaultSubAssigner, } } @@ -108,6 +117,9 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw if isReservedEmail(email) { return "", nil, ErrEmailReserved } + if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil { + return "", nil, err + } // 检查是否需要邀请码 var invitationRedeemCode *RedeemCode @@ -118,12 +130,12 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw // 验证邀请码 redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode) if err != nil { - log.Printf("[Auth] Invalid invitation code: %s, error: %v", invitationCode, err) + logger.LegacyPrintf("service.auth", "[Auth] Invalid invitation code: %s, error: %v", invitationCode, err) return "", nil, ErrInvitationCodeInvalid } // 检查类型和状态 if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused { - log.Printf("[Auth] Invitation code invalid: type=%s, status=%s", redeemCode.Type, redeemCode.Status) + logger.LegacyPrintf("service.auth", "[Auth] Invitation code invalid: type=%s, status=%s", 
redeemCode.Type, redeemCode.Status) return "", nil, ErrInvitationCodeInvalid } invitationRedeemCode = redeemCode @@ -134,7 +146,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw // 如果邮件验证已开启但邮件服务未配置,拒绝注册 // 这是一个配置错误,不应该允许绕过验证 if s.emailService == nil { - log.Println("[Auth] Email verification enabled but email service not configured, rejecting registration") + logger.LegacyPrintf("service.auth", "%s", "[Auth] Email verification enabled but email service not configured, rejecting registration") return "", nil, ErrServiceUnavailable } if verifyCode == "" { @@ -149,7 +161,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { - log.Printf("[Auth] Database error checking email exists: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error checking email exists: %v", err) return "", nil, ErrServiceUnavailable } if existsEmail { @@ -185,22 +197,23 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw if errors.Is(err, ErrEmailExists) { return "", nil, ErrEmailExists } - log.Printf("[Auth] Database error creating user: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err) return "", nil, ErrServiceUnavailable } + s.assignDefaultSubscriptions(ctx, user.ID) // 标记邀请码为已使用(如果使用了邀请码) if invitationRedeemCode != nil { if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { // 邀请码标记失败不影响注册,只记录日志 - log.Printf("[Auth] Failed to mark invitation code as used for user %d: %v", user.ID, err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to mark invitation code as used for user %d: %v", user.ID, err) } } // 应用优惠码(如果提供且功能已启用) if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) { if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil { // 
优惠码应用失败不影响注册,只记录日志 - log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to apply promo code for user %d: %v", user.ID, err) } else { // 重新获取用户信息以获取更新后的余额 if updatedUser, err := s.userRepo.GetByID(ctx, user.ID); err == nil { @@ -233,11 +246,14 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { if isReservedEmail(email) { return ErrEmailReserved } + if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil { + return err + } // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { - log.Printf("[Auth] Database error checking email exists: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error checking email exists: %v", err) return ErrServiceUnavailable } if existsEmail { @@ -260,32 +276,35 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { // SendVerifyCodeAsync 异步发送邮箱验证码并返回倒计时 func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) { - log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] SendVerifyCodeAsync called for email: %s", email) // 检查是否开放注册(默认关闭) if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { - log.Println("[Auth] Registration is disabled") + logger.LegacyPrintf("service.auth", "%s", "[Auth] Registration is disabled") return nil, ErrRegDisabled } if isReservedEmail(email) { return nil, ErrEmailReserved } + if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil { + return nil, err + } // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { - log.Printf("[Auth] Database error checking email exists: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error checking email exists: %v", err) return nil, ErrServiceUnavailable } if existsEmail { - log.Printf("[Auth] Email 
already exists: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] Email already exists: %s", email) return nil, ErrEmailExists } // 检查邮件队列服务是否配置 if s.emailQueueService == nil { - log.Println("[Auth] Email queue service not configured") + logger.LegacyPrintf("service.auth", "%s", "[Auth] Email queue service not configured") return nil, errors.New("email queue service not configured") } @@ -296,45 +315,56 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S } // 异步发送 - log.Printf("[Auth] Enqueueing verify code for: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] Enqueueing verify code for: %s", email) if err := s.emailQueueService.EnqueueVerifyCode(email, siteName); err != nil { - log.Printf("[Auth] Failed to enqueue: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to enqueue: %v", err) return nil, fmt.Errorf("enqueue verify code: %w", err) } - log.Printf("[Auth] Verify code enqueued successfully for: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] Verify code enqueued successfully for: %s", email) return &SendVerifyCodeResult{ Countdown: 60, // 60秒倒计时 }, nil } +// VerifyTurnstileForRegister 在注册场景下验证 Turnstile。 +// 当邮箱验证开启且已提交验证码时,说明验证码发送阶段已完成 Turnstile 校验, +// 此处跳过二次校验,避免一次性 token 在注册提交时重复使用导致误报失败。 +func (s *AuthService) VerifyTurnstileForRegister(ctx context.Context, token, remoteIP, verifyCode string) error { + if s.IsEmailVerifyEnabled(ctx) && strings.TrimSpace(verifyCode) != "" { + logger.LegacyPrintf("service.auth", "%s", "[Auth] Email verify flow detected, skip duplicate Turnstile check on register") + return nil + } + return s.VerifyTurnstile(ctx, token, remoteIP) +} + // VerifyTurnstile 验证Turnstile token func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error { required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required if required { if s.settingService == nil { - log.Println("[Auth] Turnstile required but settings 
service is not configured") + logger.LegacyPrintf("service.auth", "%s", "[Auth] Turnstile required but settings service is not configured") return ErrTurnstileNotConfigured } enabled := s.settingService.IsTurnstileEnabled(ctx) secretConfigured := s.settingService.GetTurnstileSecretKey(ctx) != "" if !enabled || !secretConfigured { - log.Printf("[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured) + logger.LegacyPrintf("service.auth", "[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured) return ErrTurnstileNotConfigured } } if s.turnstileService == nil { if required { - log.Println("[Auth] Turnstile required but service not configured") + logger.LegacyPrintf("service.auth", "%s", "[Auth] Turnstile required but service not configured") return ErrTurnstileNotConfigured } return nil // 服务未配置则跳过验证 } if !required && s.settingService != nil && s.settingService.IsTurnstileEnabled(ctx) && s.settingService.GetTurnstileSecretKey(ctx) == "" { - log.Println("[Auth] Turnstile enabled but secret key not configured") + logger.LegacyPrintf("service.auth", "%s", "[Auth] Turnstile enabled but secret key not configured") } return s.turnstileService.VerifyToken(ctx, token, remoteIP) @@ -373,7 +403,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string return "", nil, ErrInvalidCredentials } // 记录数据库错误但不暴露给用户 - log.Printf("[Auth] Database error during login: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error during login: %v", err) return "", nil, ErrServiceUnavailable } @@ -426,7 +456,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username randomPassword, err := randomHexString(32) if err != nil { - log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to generate random password for oauth signup: %v", err) return "", 
nil, ErrServiceUnavailable } hashedPassword, err := s.HashPassword(randomPassword) @@ -457,18 +487,19 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username // 并发场景:GetByEmail 与 Create 之间用户被创建。 user, err = s.userRepo.GetByEmail(ctx, email) if err != nil { - log.Printf("[Auth] Database error getting user after conflict: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err) return "", nil, ErrServiceUnavailable } } else { - log.Printf("[Auth] Database error creating oauth user: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err) return "", nil, ErrServiceUnavailable } } else { user = newUser + s.assignDefaultSubscriptions(ctx, user.ID) } } else { - log.Printf("[Auth] Database error during oauth login: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err) return "", nil, ErrServiceUnavailable } } @@ -481,7 +512,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username if user.Username == "" && username != "" { user.Username = username if err := s.userRepo.Update(ctx, user); err != nil { - log.Printf("[Auth] Failed to update username after oauth login: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err) } } @@ -523,7 +554,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema randomPassword, err := randomHexString(32) if err != nil { - log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to generate random password for oauth signup: %v", err) return nil, nil, ErrServiceUnavailable } hashedPassword, err := s.HashPassword(randomPassword) @@ -552,18 +583,19 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema if errors.Is(err, ErrEmailExists) { user, err = 
s.userRepo.GetByEmail(ctx, email) if err != nil { - log.Printf("[Auth] Database error getting user after conflict: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err) return nil, nil, ErrServiceUnavailable } } else { - log.Printf("[Auth] Database error creating oauth user: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err) return nil, nil, ErrServiceUnavailable } } else { user = newUser + s.assignDefaultSubscriptions(ctx, user.ID) } } else { - log.Printf("[Auth] Database error during oauth login: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err) return nil, nil, ErrServiceUnavailable } } @@ -575,7 +607,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema if user.Username == "" && username != "" { user.Username = username if err := s.userRepo.Update(ctx, user); err != nil { - log.Printf("[Auth] Failed to update username after oauth login: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err) } } @@ -586,6 +618,49 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return tokenPair, user, nil } +func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) { + if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { + return + } + items := s.settingService.GetDefaultSubscriptions(ctx) + for _, item := range items { + if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{ + UserID: userID, + GroupID: item.GroupID, + ValidityDays: item.ValidityDays, + Notes: "auto assigned by default user subscriptions setting", + }); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err) + } + } +} + +func (s *AuthService) 
validateRegistrationEmailPolicy(ctx context.Context, email string) error { + if s.settingService == nil { + return nil + } + whitelist := s.settingService.GetRegistrationEmailSuffixWhitelist(ctx) + if !IsRegistrationEmailSuffixAllowed(email, whitelist) { + return buildEmailSuffixNotAllowedError(whitelist) + } + return nil +} + +func buildEmailSuffixNotAllowedError(whitelist []string) error { + if len(whitelist) == 0 { + return ErrEmailSuffixNotAllowed + } + + allowed := strings.Join(whitelist, ", ") + return infraerrors.BadRequest( + "EMAIL_SUFFIX_NOT_ALLOWED", + fmt.Sprintf("email suffix is not allowed, allowed suffixes: %s", allowed), + ).WithMetadata(map[string]string{ + "allowed_suffixes": strings.Join(whitelist, ","), + "allowed_suffix_count": strconv.Itoa(len(whitelist)), + }) +} + // ValidateToken 验证JWT token并返回用户声明 func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { // 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。 @@ -715,7 +790,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) ( if errors.Is(err, ErrUserNotFound) { return "", ErrInvalidToken } - log.Printf("[Auth] Database error refreshing token: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error refreshing token: %v", err) return "", ErrServiceUnavailable } @@ -756,16 +831,16 @@ func (s *AuthService) preparePasswordReset(ctx context.Context, email, frontendB if err != nil { if errors.Is(err, ErrUserNotFound) { // Security: Log but don't reveal that user doesn't exist - log.Printf("[Auth] Password reset requested for non-existent email: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] Password reset requested for non-existent email: %s", email) return "", "", false } - log.Printf("[Auth] Database error checking email for password reset: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error checking email for password reset: %v", err) return "", "", false } // Check if user is active if !user.IsActive() { - 
log.Printf("[Auth] Password reset requested for inactive user: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] Password reset requested for inactive user: %s", email) return "", "", false } @@ -797,11 +872,11 @@ func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendB } if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil { - log.Printf("[Auth] Failed to send password reset email to %s: %v", email, err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to send password reset email to %s: %v", email, err) return nil // Silent success to prevent enumeration } - log.Printf("[Auth] Password reset email sent to: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] Password reset email sent to: %s", email) return nil } @@ -821,11 +896,11 @@ func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, fron } if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL); err != nil { - log.Printf("[Auth] Failed to enqueue password reset email for %s: %v", email, err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to enqueue password reset email for %s: %v", email, err) return nil // Silent success to prevent enumeration } - log.Printf("[Auth] Password reset email enqueued for: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] Password reset email enqueued for: %s", email) return nil } @@ -852,7 +927,7 @@ func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPasswo if errors.Is(err, ErrUserNotFound) { return ErrInvalidResetToken // Token was valid but user was deleted } - log.Printf("[Auth] Database error getting user for password reset: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error getting user for password reset: %v", err) return ErrServiceUnavailable } @@ -872,17 +947,17 @@ func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPasswo user.TokenVersion++ // Invalidate all existing 
tokens if err := s.userRepo.Update(ctx, user); err != nil { - log.Printf("[Auth] Database error updating password for user %d: %v", user.ID, err) + logger.LegacyPrintf("service.auth", "[Auth] Database error updating password for user %d: %v", user.ID, err) return ErrServiceUnavailable } // Also revoke all refresh tokens for this user if err := s.RevokeAllUserSessions(ctx, user.ID); err != nil { - log.Printf("[Auth] Failed to revoke refresh tokens for user %d: %v", user.ID, err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to revoke refresh tokens for user %d: %v", user.ID, err) // Don't return error - password was already changed successfully } - log.Printf("[Auth] Password reset successful for user: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] Password reset successful for user: %s", email) return nil } @@ -961,13 +1036,13 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami // 添加到用户Token集合 if err := s.refreshTokenCache.AddToUserTokenSet(ctx, user.ID, tokenHash, ttl); err != nil { - log.Printf("[Auth] Failed to add token to user set: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to add token to user set: %v", err) // 不影响主流程 } // 添加到家族Token集合 if err := s.refreshTokenCache.AddToFamilyTokenSet(ctx, familyID, tokenHash, ttl); err != nil { - log.Printf("[Auth] Failed to add token to family set: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to add token to family set: %v", err) // 不影响主流程 } @@ -994,10 +1069,10 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) if err != nil { if errors.Is(err, ErrRefreshTokenNotFound) { // Token不存在,可能是已被使用(Token轮转)或已过期 - log.Printf("[Auth] Refresh token not found, possible reuse attack") + logger.LegacyPrintf("service.auth", "[Auth] Refresh token not found, possible reuse attack") return nil, ErrRefreshTokenInvalid } - log.Printf("[Auth] Error getting refresh token: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] 
Error getting refresh token: %v", err) return nil, ErrServiceUnavailable } @@ -1016,7 +1091,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) _ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID) return nil, ErrRefreshTokenInvalid } - log.Printf("[Auth] Database error getting user for token refresh: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error getting user for token refresh: %v", err) return nil, ErrServiceUnavailable } @@ -1036,7 +1111,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) // Token轮转:立即使旧Token失效 if err := s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash); err != nil { - log.Printf("[Auth] Failed to delete old refresh token: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to delete old refresh token: %v", err) // 继续处理,不影响主流程 } diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index f1685be5..b139fdcd 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/stretchr/testify/require" ) @@ -56,6 +57,21 @@ type emailCacheStub struct { err error } +type defaultSubscriptionAssignerStub struct { + calls []AssignSubscriptionInput + err error +} + +func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) { + if input != nil { + s.calls = append(s.calls, *input) + } + if s.err != nil { + return nil, false, s.err + } + return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil +} + func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) { if s.err != nil { return nil, s.err 
@@ -123,6 +139,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E nil, nil, nil, // promoService + nil, // defaultSubAssigner ) } @@ -215,6 +232,51 @@ func TestAuthService_Register_ReservedEmail(t *testing.T) { require.ErrorIs(t, err, ErrEmailReserved) } +func TestAuthService_Register_EmailSuffixNotAllowed(t *testing.T) { + repo := &userRepoStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`, + }, nil) + + _, _, err := service.Register(context.Background(), "user@other.com", "password") + require.ErrorIs(t, err, ErrEmailSuffixNotAllowed) + appErr := infraerrors.FromError(err) + require.Contains(t, appErr.Message, "@example.com") + require.Contains(t, appErr.Message, "@company.com") + require.Equal(t, "EMAIL_SUFFIX_NOT_ALLOWED", appErr.Reason) + require.Equal(t, "2", appErr.Metadata["allowed_suffix_count"]) + require.Equal(t, "@example.com,@company.com", appErr.Metadata["allowed_suffixes"]) +} + +func TestAuthService_Register_EmailSuffixAllowed(t *testing.T) { + repo := &userRepoStub{nextID: 8} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyRegistrationEmailSuffixWhitelist: `["example.com"]`, + }, nil) + + _, user, err := service.Register(context.Background(), "user@example.com", "password") + require.NoError(t, err) + require.NotNil(t, user) + require.Equal(t, int64(8), user.ID) +} + +func TestAuthService_SendVerifyCode_EmailSuffixNotAllowed(t *testing.T) { + repo := &userRepoStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`, + }, nil) + + err := service.SendVerifyCode(context.Background(), "user@other.com") + require.ErrorIs(t, err, ErrEmailSuffixNotAllowed) + appErr := infraerrors.FromError(err) + 
require.Contains(t, appErr.Message, "@example.com") + require.Contains(t, appErr.Message, "@company.com") + require.Equal(t, "2", appErr.Metadata["allowed_suffix_count"]) +} + func TestAuthService_Register_CreateError(t *testing.T) { repo := &userRepoStub{createErr: errors.New("create failed")} service := newAuthService(repo, map[string]string{ @@ -315,3 +377,89 @@ func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) { require.NotEmpty(t, newToken) }) } + +func TestAuthService_GetAccessTokenExpiresIn_FallbackToExpireHour(t *testing.T) { + service := newAuthService(&userRepoStub{}, nil, nil) + service.cfg.JWT.ExpireHour = 24 + service.cfg.JWT.AccessTokenExpireMinutes = 0 + + require.Equal(t, 24*3600, service.GetAccessTokenExpiresIn()) +} + +func TestAuthService_GetAccessTokenExpiresIn_MinutesHasPriority(t *testing.T) { + service := newAuthService(&userRepoStub{}, nil, nil) + service.cfg.JWT.ExpireHour = 24 + service.cfg.JWT.AccessTokenExpireMinutes = 90 + + require.Equal(t, 90*60, service.GetAccessTokenExpiresIn()) +} + +func TestAuthService_GenerateToken_UsesExpireHourWhenMinutesZero(t *testing.T) { + service := newAuthService(&userRepoStub{}, nil, nil) + service.cfg.JWT.ExpireHour = 24 + service.cfg.JWT.AccessTokenExpireMinutes = 0 + + user := &User{ + ID: 1, + Email: "test@test.com", + Role: RoleUser, + Status: StatusActive, + TokenVersion: 1, + } + + token, err := service.GenerateToken(user) + require.NoError(t, err) + + claims, err := service.ValidateToken(token) + require.NoError(t, err) + require.NotNil(t, claims) + require.NotNil(t, claims.IssuedAt) + require.NotNil(t, claims.ExpiresAt) + + require.WithinDuration(t, claims.IssuedAt.Time.Add(24*time.Hour), claims.ExpiresAt.Time, 2*time.Second) +} + +func TestAuthService_GenerateToken_UsesMinutesWhenConfigured(t *testing.T) { + service := newAuthService(&userRepoStub{}, nil, nil) + service.cfg.JWT.ExpireHour = 24 + service.cfg.JWT.AccessTokenExpireMinutes = 90 + + user := &User{ + ID: 2, + 
Email: "test2@test.com", + Role: RoleUser, + Status: StatusActive, + TokenVersion: 1, + } + + token, err := service.GenerateToken(user) + require.NoError(t, err) + + claims, err := service.ValidateToken(token) + require.NoError(t, err) + require.NotNil(t, claims) + require.NotNil(t, claims.IssuedAt) + require.NotNil(t, claims.ExpiresAt) + + require.WithinDuration(t, claims.IssuedAt.Time.Add(90*time.Minute), claims.ExpiresAt.Time, 2*time.Second) +} + +func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) { + repo := &userRepoStub{nextID: 42} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`, + }, nil) + service.defaultSubAssigner = assigner + + _, user, err := service.Register(context.Background(), "default-sub@test.com", "password") + require.NoError(t, err) + require.NotNil(t, user) + require.Len(t, assigner.calls, 2) + require.Equal(t, int64(42), assigner.calls[0].UserID) + require.Equal(t, int64(11), assigner.calls[0].GroupID) + require.Equal(t, 30, assigner.calls[0].ValidityDays) + require.Equal(t, int64(12), assigner.calls[1].GroupID) + require.Equal(t, 7, assigner.calls[1].ValidityDays) +} diff --git a/backend/internal/service/auth_service_turnstile_register_test.go b/backend/internal/service/auth_service_turnstile_register_test.go new file mode 100644 index 00000000..36cb1e06 --- /dev/null +++ b/backend/internal/service/auth_service_turnstile_register_test.go @@ -0,0 +1,97 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type turnstileVerifierSpy struct { + called int + lastToken string + result *TurnstileVerifyResponse + err error +} + +func (s *turnstileVerifierSpy) VerifyToken(_ context.Context, _ string, token, _ string) 
(*TurnstileVerifyResponse, error) { + s.called++ + s.lastToken = token + if s.err != nil { + return nil, s.err + } + if s.result != nil { + return s.result, nil + } + return &TurnstileVerifyResponse{Success: true}, nil +} + +func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier TurnstileVerifier) *AuthService { + cfg := &config.Config{ + Server: config.ServerConfig{ + Mode: "release", + }, + Turnstile: config.TurnstileConfig{ + Required: true, + }, + } + + settingService := NewSettingService(&settingRepoStub{values: settings}, cfg) + turnstileService := NewTurnstileService(settingService, verifier) + + return NewAuthService( + &userRepoStub{}, + nil, // redeemRepo + nil, // refreshTokenCache + cfg, + settingService, + nil, // emailService + turnstileService, + nil, // emailQueueService + nil, // promoService + nil, // defaultSubAssigner + ) +} + +func TestAuthService_VerifyTurnstileForRegister_SkipWhenEmailVerifyCodeProvided(t *testing.T) { + verifier := &turnstileVerifierSpy{} + service := newAuthServiceForRegisterTurnstileTest(map[string]string{ + SettingKeyEmailVerifyEnabled: "true", + SettingKeyTurnstileEnabled: "true", + SettingKeyTurnstileSecretKey: "secret", + SettingKeyRegistrationEnabled: "true", + }, verifier) + + err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "123456") + require.NoError(t, err) + require.Equal(t, 0, verifier.called) +} + +func TestAuthService_VerifyTurnstileForRegister_RequireWhenVerifyCodeMissing(t *testing.T) { + verifier := &turnstileVerifierSpy{} + service := newAuthServiceForRegisterTurnstileTest(map[string]string{ + SettingKeyEmailVerifyEnabled: "true", + SettingKeyTurnstileEnabled: "true", + SettingKeyTurnstileSecretKey: "secret", + }, verifier) + + err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "") + require.ErrorIs(t, err, ErrTurnstileVerificationFailed) +} + +func 
TestAuthService_VerifyTurnstileForRegister_NoSkipWhenEmailVerifyDisabled(t *testing.T) { + verifier := &turnstileVerifierSpy{} + service := newAuthServiceForRegisterTurnstileTest(map[string]string{ + SettingKeyEmailVerifyEnabled: "false", + SettingKeyTurnstileEnabled: "true", + SettingKeyTurnstileSecretKey: "secret", + }, verifier) + + err := service.VerifyTurnstileForRegister(context.Background(), "turnstile-token", "127.0.0.1", "123456") + require.NoError(t, err) + require.Equal(t, 1, verifier.called) + require.Equal(t, "turnstile-token", verifier.lastToken) +} diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index c09cafb9..e055c0f7 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -3,13 +3,15 @@ package service import ( "context" "fmt" - "log" + "strconv" "sync" "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "golang.org/x/sync/singleflight" ) // 错误定义 @@ -38,6 +40,7 @@ const ( cacheWriteSetSubscription cacheWriteUpdateSubscriptionUsage cacheWriteDeductBalance + cacheWriteUpdateRateLimitUsage ) // 异步缓存写入工作池配置 @@ -58,6 +61,7 @@ const ( cacheWriteBufferSize = 1000 // 任务队列缓冲大小 cacheWriteTimeout = 2 * time.Second // 单个写入操作超时 cacheWriteDropLogInterval = 5 * time.Second // 丢弃日志节流间隔 + balanceLoadTimeout = 3 * time.Second ) // cacheWriteTask 缓存写入任务 @@ -65,23 +69,33 @@ type cacheWriteTask struct { kind cacheWriteKind userID int64 groupID int64 + apiKeyID int64 balance float64 amount float64 subscriptionData *subscriptionCacheData } +// apiKeyRateLimitLoader defines the interface for loading rate limit data from DB. 
+type apiKeyRateLimitLoader interface { + GetRateLimitData(ctx context.Context, keyID int64) (*APIKeyRateLimitData, error) +} + // BillingCacheService 计费缓存服务 // 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查 type BillingCacheService struct { - cache BillingCache - userRepo UserRepository - subRepo UserSubscriptionRepository - cfg *config.Config - circuitBreaker *billingCircuitBreaker + cache BillingCache + userRepo UserRepository + subRepo UserSubscriptionRepository + apiKeyRateLimitLoader apiKeyRateLimitLoader + cfg *config.Config + circuitBreaker *billingCircuitBreaker cacheWriteChan chan cacheWriteTask cacheWriteWg sync.WaitGroup cacheWriteStopOnce sync.Once + cacheWriteMu sync.RWMutex + stopped atomic.Bool + balanceLoadSF singleflight.Group // 丢弃日志节流计数器(减少高负载下日志噪音) cacheWriteDropFullCount uint64 cacheWriteDropFullLastLog int64 @@ -90,12 +104,13 @@ type BillingCacheService struct { } // NewBillingCacheService 创建计费缓存服务 -func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, cfg *config.Config) *BillingCacheService { +func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, apiKeyRepo APIKeyRepository, cfg *config.Config) *BillingCacheService { svc := &BillingCacheService{ - cache: cache, - userRepo: userRepo, - subRepo: subRepo, - cfg: cfg, + cache: cache, + userRepo: userRepo, + subRepo: subRepo, + apiKeyRateLimitLoader: apiKeyRepo, + cfg: cfg, } svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker) svc.startCacheWriteWorkers() @@ -105,35 +120,52 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo // Stop 关闭缓存写入工作池 func (s *BillingCacheService) Stop() { s.cacheWriteStopOnce.Do(func() { - if s.cacheWriteChan == nil { + s.stopped.Store(true) + + s.cacheWriteMu.Lock() + ch := s.cacheWriteChan + if ch != nil { + close(ch) + } + s.cacheWriteMu.Unlock() + + if ch == nil { return } - close(s.cacheWriteChan) s.cacheWriteWg.Wait() - 
s.cacheWriteChan = nil + + s.cacheWriteMu.Lock() + if s.cacheWriteChan == ch { + s.cacheWriteChan = nil + } + s.cacheWriteMu.Unlock() }) } func (s *BillingCacheService) startCacheWriteWorkers() { - s.cacheWriteChan = make(chan cacheWriteTask, cacheWriteBufferSize) + ch := make(chan cacheWriteTask, cacheWriteBufferSize) + s.cacheWriteChan = ch for i := 0; i < cacheWriteWorkerCount; i++ { s.cacheWriteWg.Add(1) - go s.cacheWriteWorker() + go s.cacheWriteWorker(ch) } } // enqueueCacheWrite 尝试将任务入队,队列满时返回 false(并记录告警)。 func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) { - if s.cacheWriteChan == nil { + if s.stopped.Load() { + s.logCacheWriteDrop(task, "closed") return false } - defer func() { - if recovered := recover(); recovered != nil { - // 队列已关闭时可能触发 panic,记录后静默失败。 - s.logCacheWriteDrop(task, "closed") - enqueued = false - } - }() + + s.cacheWriteMu.RLock() + defer s.cacheWriteMu.RUnlock() + + if s.cacheWriteChan == nil { + s.logCacheWriteDrop(task, "closed") + return false + } + select { case s.cacheWriteChan <- task: return true @@ -144,9 +176,9 @@ func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued b } } -func (s *BillingCacheService) cacheWriteWorker() { +func (s *BillingCacheService) cacheWriteWorker(ch <-chan cacheWriteTask) { defer s.cacheWriteWg.Done() - for task := range s.cacheWriteChan { + for task := range ch { ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) switch task.kind { case cacheWriteSetBalance: @@ -156,13 +188,19 @@ func (s *BillingCacheService) cacheWriteWorker() { case cacheWriteUpdateSubscriptionUsage: if s.cache != nil { if err := s.cache.UpdateSubscriptionUsage(ctx, task.userID, task.groupID, task.amount); err != nil { - log.Printf("Warning: update subscription cache failed for user %d group %d: %v", task.userID, task.groupID, err) + logger.LegacyPrintf("service.billing_cache", "Warning: update subscription cache failed for user %d group %d: %v", 
task.userID, task.groupID, err) } } case cacheWriteDeductBalance: if s.cache != nil { if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil { - log.Printf("Warning: deduct balance cache failed for user %d: %v", task.userID, err) + logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache failed for user %d: %v", task.userID, err) + } + } + case cacheWriteUpdateRateLimitUsage: + if s.cache != nil { + if err := s.cache.UpdateAPIKeyRateLimitUsage(ctx, task.apiKeyID, task.amount); err != nil { + logger.LegacyPrintf("service.billing_cache", "Warning: update rate limit usage cache failed for api key %d: %v", task.apiKeyID, err) } } } @@ -181,6 +219,8 @@ func cacheWriteKindName(kind cacheWriteKind) string { return "update_subscription_usage" case cacheWriteDeductBalance: return "deduct_balance" + case cacheWriteUpdateRateLimitUsage: + return "update_rate_limit_usage" default: return "unknown" } @@ -216,7 +256,7 @@ func (s *BillingCacheService) logCacheWriteDrop(task cacheWriteTask, reason stri if dropped == 0 { return } - log.Printf("Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)", + logger.LegacyPrintf("service.billing_cache", "Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)", reason, dropped, cacheWriteDropLogInterval, @@ -243,19 +283,31 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) return balance, nil } - // 缓存未命中,从数据库读取 - balance, err = s.getUserBalanceFromDB(ctx, userID) + // 缓存未命中:singleflight 合并同一 userID 的并发回源请求。 + value, err, _ := s.balanceLoadSF.Do(strconv.FormatInt(userID, 10), func() (any, error) { + loadCtx, cancel := context.WithTimeout(context.Background(), balanceLoadTimeout) + defer cancel() + + balance, err := s.getUserBalanceFromDB(loadCtx, userID) + if err != nil { + return nil, err + } + + // 异步建立缓存 + _ = s.enqueueCacheWrite(cacheWriteTask{ + kind: cacheWriteSetBalance, + userID: 
userID, + balance: balance, + }) + return balance, nil + }) if err != nil { return 0, err } - - // 异步建立缓存 - _ = s.enqueueCacheWrite(cacheWriteTask{ - kind: cacheWriteSetBalance, - userID: userID, - balance: balance, - }) - + balance, ok := value.(float64) + if !ok { + return 0, fmt.Errorf("unexpected balance type: %T", value) + } return balance, nil } @@ -274,7 +326,7 @@ func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64, return } if err := s.cache.SetUserBalance(ctx, userID, balance); err != nil { - log.Printf("Warning: set balance cache failed for user %d: %v", userID, err) + logger.LegacyPrintf("service.billing_cache", "Warning: set balance cache failed for user %d: %v", userID, err) } } @@ -302,7 +354,7 @@ func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) { ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) defer cancel() if err := s.DeductBalanceCache(ctx, userID, amount); err != nil { - log.Printf("Warning: deduct balance cache fallback failed for user %d: %v", userID, err) + logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache fallback failed for user %d: %v", userID, err) } } @@ -312,7 +364,7 @@ func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID return nil } if err := s.cache.InvalidateUserBalance(ctx, userID); err != nil { - log.Printf("Warning: invalidate balance cache failed for user %d: %v", userID, err) + logger.LegacyPrintf("service.billing_cache", "Warning: invalidate balance cache failed for user %d: %v", userID, err) return err } return nil @@ -396,7 +448,7 @@ func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID, return } if err := s.cache.SetSubscriptionCache(ctx, userID, groupID, s.convertToPortsData(data)); err != nil { - log.Printf("Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err) + logger.LegacyPrintf("service.billing_cache", "Warning: set 
subscription cache failed for user %d group %d: %v", userID, groupID, err) } } @@ -425,7 +477,7 @@ func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64 ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) defer cancel() if err := s.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD); err != nil { - log.Printf("Warning: update subscription cache fallback failed for user %d group %d: %v", userID, groupID, err) + logger.LegacyPrintf("service.billing_cache", "Warning: update subscription cache fallback failed for user %d group %d: %v", userID, groupID, err) } } @@ -435,12 +487,143 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID return nil } if err := s.cache.InvalidateSubscriptionCache(ctx, userID, groupID); err != nil { - log.Printf("Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err) + logger.LegacyPrintf("service.billing_cache", "Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err) return err } return nil } +// ============================================ +// API Key 限速缓存方法 +// ============================================ + +// checkAPIKeyRateLimits checks rate limit windows for an API key. +// It loads usage from Redis cache (falling back to DB on cache miss), +// resets expired windows in-memory and triggers async DB reset, +// and returns an error if any window limit is exceeded. 
+func (s *BillingCacheService) checkAPIKeyRateLimits(ctx context.Context, apiKey *APIKey) error { + if s.cache == nil { + // No cache: fall back to reading from DB directly + if s.apiKeyRateLimitLoader == nil { + return nil + } + data, err := s.apiKeyRateLimitLoader.GetRateLimitData(ctx, apiKey.ID) + if err != nil { + return nil // Don't block requests on DB errors + } + return s.evaluateRateLimits(ctx, apiKey, data.Usage5h, data.Usage1d, data.Usage7d, + data.Window5hStart, data.Window1dStart, data.Window7dStart) + } + + cacheData, err := s.cache.GetAPIKeyRateLimit(ctx, apiKey.ID) + if err != nil { + // Cache miss: load from DB and populate cache + if s.apiKeyRateLimitLoader == nil { + return nil + } + dbData, dbErr := s.apiKeyRateLimitLoader.GetRateLimitData(ctx, apiKey.ID) + if dbErr != nil { + return nil // Don't block requests on DB errors + } + // Build cache entry from DB data + cacheEntry := &APIKeyRateLimitCacheData{ + Usage5h: dbData.Usage5h, + Usage1d: dbData.Usage1d, + Usage7d: dbData.Usage7d, + } + if dbData.Window5hStart != nil { + cacheEntry.Window5h = dbData.Window5hStart.Unix() + } + if dbData.Window1dStart != nil { + cacheEntry.Window1d = dbData.Window1dStart.Unix() + } + if dbData.Window7dStart != nil { + cacheEntry.Window7d = dbData.Window7dStart.Unix() + } + _ = s.cache.SetAPIKeyRateLimit(ctx, apiKey.ID, cacheEntry) + cacheData = cacheEntry + } + + var w5h, w1d, w7d *time.Time + if cacheData.Window5h > 0 { + t := time.Unix(cacheData.Window5h, 0) + w5h = &t + } + if cacheData.Window1d > 0 { + t := time.Unix(cacheData.Window1d, 0) + w1d = &t + } + if cacheData.Window7d > 0 { + t := time.Unix(cacheData.Window7d, 0) + w7d = &t + } + return s.evaluateRateLimits(ctx, apiKey, cacheData.Usage5h, cacheData.Usage1d, cacheData.Usage7d, w5h, w1d, w7d) +} + +// evaluateRateLimits checks usage against limits, triggering async resets for expired windows. 
+func (s *BillingCacheService) evaluateRateLimits(ctx context.Context, apiKey *APIKey, usage5h, usage1d, usage7d float64, w5h, w1d, w7d *time.Time) error { + needsReset := false + + // Reset expired windows in-memory for check purposes + if w5h != nil && time.Since(*w5h) >= 5*time.Hour { + usage5h = 0 + needsReset = true + } + if w1d != nil && time.Since(*w1d) >= 24*time.Hour { + usage1d = 0 + needsReset = true + } + if w7d != nil && time.Since(*w7d) >= 7*24*time.Hour { + usage7d = 0 + needsReset = true + } + + // Trigger async DB reset if any window expired + if needsReset { + keyID := apiKey.ID + go func() { + resetCtx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) + defer cancel() + if s.apiKeyRateLimitLoader != nil { + // Use the repo directly - reset then reload cache + if loader, ok := s.apiKeyRateLimitLoader.(interface { + ResetRateLimitWindows(ctx context.Context, id int64) error + }); ok { + _ = loader.ResetRateLimitWindows(resetCtx, keyID) + } + } + // Invalidate cache so next request loads fresh data + if s.cache != nil { + _ = s.cache.InvalidateAPIKeyRateLimit(resetCtx, keyID) + } + }() + } + + // Check limits + if apiKey.RateLimit5h > 0 && usage5h >= apiKey.RateLimit5h { + return ErrAPIKeyRateLimit5hExceeded + } + if apiKey.RateLimit1d > 0 && usage1d >= apiKey.RateLimit1d { + return ErrAPIKeyRateLimit1dExceeded + } + if apiKey.RateLimit7d > 0 && usage7d >= apiKey.RateLimit7d { + return ErrAPIKeyRateLimit7dExceeded + } + return nil +} + +// QueueUpdateAPIKeyRateLimitUsage asynchronously updates rate limit usage in the cache. 
+func (s *BillingCacheService) QueueUpdateAPIKeyRateLimitUsage(apiKeyID int64, cost float64) { + if s.cache == nil { + return + } + s.enqueueCacheWrite(cacheWriteTask{ + kind: cacheWriteUpdateRateLimitUsage, + apiKeyID: apiKeyID, + amount: cost, + }) +} + // ============================================ // 统一检查方法 // ============================================ @@ -461,10 +644,23 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil if isSubscriptionMode { - return s.checkSubscriptionEligibility(ctx, user.ID, group, subscription) + if err := s.checkSubscriptionEligibility(ctx, user.ID, group, subscription); err != nil { + return err + } + } else { + if err := s.checkBalanceEligibility(ctx, user.ID); err != nil { + return err + } } - return s.checkBalanceEligibility(ctx, user.ID) + // Check API Key rate limits (applies to both billing modes) + if apiKey != nil && apiKey.HasRateLimits() { + if err := s.checkAPIKeyRateLimits(ctx, apiKey); err != nil { + return err + } + } + + return nil } // checkBalanceEligibility 检查余额模式资格 @@ -474,7 +670,7 @@ func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userI if s.circuitBreaker != nil { s.circuitBreaker.OnFailure(err) } - log.Printf("ALERT: billing balance check failed for user %d: %v", userID, err) + logger.LegacyPrintf("service.billing_cache", "ALERT: billing balance check failed for user %d: %v", userID, err) return ErrBillingServiceUnavailable.WithCause(err) } if s.circuitBreaker != nil { @@ -496,7 +692,7 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, if s.circuitBreaker != nil { s.circuitBreaker.OnFailure(err) } - log.Printf("ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err) + logger.LegacyPrintf("service.billing_cache", "ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, 
err) return ErrBillingServiceUnavailable.WithCause(err) } if s.circuitBreaker != nil { @@ -585,7 +781,7 @@ func (b *billingCircuitBreaker) Allow() bool { } b.state = billingCircuitHalfOpen b.halfOpenRemaining = b.halfOpenRequests - log.Printf("ALERT: billing circuit breaker entering half-open state") + logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker entering half-open state") fallthrough case billingCircuitHalfOpen: if b.halfOpenRemaining <= 0 { @@ -612,7 +808,7 @@ func (b *billingCircuitBreaker) OnFailure(err error) { b.state = billingCircuitOpen b.openedAt = time.Now() b.halfOpenRemaining = 0 - log.Printf("ALERT: billing circuit breaker opened after half-open failure: %v", err) + logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker opened after half-open failure: %v", err) return default: b.failures++ @@ -620,7 +816,7 @@ func (b *billingCircuitBreaker) OnFailure(err error) { b.state = billingCircuitOpen b.openedAt = time.Now() b.halfOpenRemaining = 0 - log.Printf("ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err) + logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err) } } } @@ -641,9 +837,9 @@ func (b *billingCircuitBreaker) OnSuccess() { // 只有状态真正发生变化时才记录日志 if previousState != billingCircuitClosed { - log.Printf("ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState)) + logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState)) } else if previousFailures > 0 { - log.Printf("INFO: billing circuit breaker failures reset from %d", previousFailures) + logger.LegacyPrintf("service.billing_cache", "INFO: billing circuit breaker failures reset from %d", previousFailures) } } diff --git a/backend/internal/service/billing_cache_service_singleflight_test.go 
b/backend/internal/service/billing_cache_service_singleflight_test.go new file mode 100644 index 00000000..4a8b8f03 --- /dev/null +++ b/backend/internal/service/billing_cache_service_singleflight_test.go @@ -0,0 +1,131 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type billingCacheMissStub struct { + setBalanceCalls atomic.Int64 +} + +func (s *billingCacheMissStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) { + return 0, errors.New("cache miss") +} + +func (s *billingCacheMissStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error { + s.setBalanceCalls.Add(1) + return nil +} + +func (s *billingCacheMissStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error { + return nil +} + +func (s *billingCacheMissStub) InvalidateUserBalance(ctx context.Context, userID int64) error { + return nil +} + +func (s *billingCacheMissStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) { + return nil, errors.New("cache miss") +} + +func (s *billingCacheMissStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error { + return nil +} + +func (s *billingCacheMissStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error { + return nil +} + +func (s *billingCacheMissStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error { + return nil +} + +func (s *billingCacheMissStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) { + return nil, errors.New("cache miss") +} + +func (s *billingCacheMissStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error { + return nil +} + +func (s *billingCacheMissStub) 
UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error { + return nil +} + +func (s *billingCacheMissStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error { + return nil +} + +type balanceLoadUserRepoStub struct { + mockUserRepo + calls atomic.Int64 + delay time.Duration + balance float64 +} + +func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User, error) { + s.calls.Add(1) + if s.delay > 0 { + select { + case <-time.After(s.delay): + case <-ctx.Done(): + return nil, ctx.Err() + } + } + return &User{ID: id, Balance: s.balance}, nil +} + +func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) { + cache := &billingCacheMissStub{} + userRepo := &balanceLoadUserRepoStub{ + delay: 80 * time.Millisecond, + balance: 12.34, + } + svc := NewBillingCacheService(cache, userRepo, nil, nil, &config.Config{}) + t.Cleanup(svc.Stop) + + const goroutines = 16 + start := make(chan struct{}) + var wg sync.WaitGroup + errCh := make(chan error, goroutines) + balCh := make(chan float64, goroutines) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + bal, err := svc.GetUserBalance(context.Background(), 99) + errCh <- err + balCh <- bal + }() + } + + close(start) + wg.Wait() + close(errCh) + close(balCh) + + for err := range errCh { + require.NoError(t, err) + } + for bal := range balCh { + require.Equal(t, 12.34, bal) + } + + require.Equal(t, int64(1), userRepo.calls.Load(), "并发穿透应被 singleflight 合并") + require.Eventually(t, func() bool { + return cache.setBalanceCalls.Load() >= 1 + }, time.Second, 10*time.Millisecond) +} diff --git a/backend/internal/service/billing_cache_service_test.go b/backend/internal/service/billing_cache_service_test.go index 445d5319..7d7045e2 100644 --- a/backend/internal/service/billing_cache_service_test.go +++ b/backend/internal/service/billing_cache_service_test.go @@ -52,9 +52,25 @@ func (b *billingCacheWorkerStub) 
InvalidateSubscriptionCache(ctx context.Context return nil } +func (b *billingCacheWorkerStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) { + return nil, errors.New("not implemented") +} + +func (b *billingCacheWorkerStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error { + return nil +} + +func (b *billingCacheWorkerStub) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error { + return nil +} + +func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error { + return nil +} + func TestBillingCacheServiceQueueHighLoad(t *testing.T) { cache := &billingCacheWorkerStub{} - svc := NewBillingCacheService(cache, nil, nil, &config.Config{}) + svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{}) t.Cleanup(svc.Stop) start := time.Now() @@ -73,3 +89,16 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) { return atomic.LoadInt64(&cache.subscriptionUpdates) > 0 }, 2*time.Second, 10*time.Millisecond) } + +func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) { + cache := &billingCacheWorkerStub{} + svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{}) + svc.Stop() + + enqueued := svc.enqueueCacheWrite(cacheWriteTask{ + kind: cacheWriteDeductBalance, + userID: 1, + amount: 1, + }) + require.False(t, enqueued) +} diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index db5a9708..5d67c808 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -10,6 +10,16 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" ) +// APIKeyRateLimitCacheData holds rate limit usage data cached in Redis. 
+type APIKeyRateLimitCacheData struct { + Usage5h float64 `json:"usage_5h"` + Usage1d float64 `json:"usage_1d"` + Usage7d float64 `json:"usage_7d"` + Window5h int64 `json:"window_5h"` // unix timestamp, 0 = not started + Window1d int64 `json:"window_1d"` + Window7d int64 `json:"window_7d"` +} + // BillingCache defines cache operations for billing service type BillingCache interface { // Balance operations @@ -23,6 +33,12 @@ type BillingCache interface { SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error + + // API Key rate limit operations + GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) + SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error + UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error + InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error } // ModelPricing 模型价格配置(per-token价格,与LiteLLM格式一致) @@ -31,8 +47,8 @@ type ModelPricing struct { OutputPricePerToken float64 // 每token输出价格 (USD) CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD) CacheReadPricePerToken float64 // 缓存读取每token价格 (USD) - CacheCreation5mPrice float64 // 5分钟缓存创建价格(每百万token)- 仅用于硬编码回退 - CacheCreation1hPrice float64 // 1小时缓存创建价格(每百万token)- 仅用于硬编码回退 + CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD) + CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD) SupportsCacheBreakdown bool // 是否支持详细的缓存分类 } @@ -133,6 +149,18 @@ func (s *BillingService) initFallbackPricing() { CacheReadPricePerToken: 0.03e-6, // $0.03 per MTok SupportsCacheBreakdown: false, } + + // Claude 4.6 Opus (与4.5同价) + s.fallbackPrices["claude-opus-4.6"] = s.fallbackPrices["claude-opus-4.5"] + + // Gemini 3.1 Pro + s.fallbackPrices["gemini-3.1-pro"] = &ModelPricing{ + InputPricePerToken: 2e-6, 
// $2 per MTok + OutputPricePerToken: 12e-6, // $12 per MTok + CacheCreationPricePerToken: 2e-6, // $2 per MTok + CacheReadPricePerToken: 0.2e-6, // $0.20 per MTok + SupportsCacheBreakdown: false, + } } // getFallbackPricing 根据模型系列获取回退价格 @@ -141,6 +169,9 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing { // 按模型系列匹配 if strings.Contains(modelLower, "opus") { + if strings.Contains(modelLower, "4.6") || strings.Contains(modelLower, "4-6") { + return s.fallbackPrices["claude-opus-4.6"] + } if strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5") { return s.fallbackPrices["claude-opus-4.5"] } @@ -158,6 +189,9 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing { } return s.fallbackPrices["claude-3-haiku"] } + if strings.Contains(modelLower, "gemini-3.1-pro") || strings.Contains(modelLower, "gemini-3-1-pro") { + return s.fallbackPrices["gemini-3.1-pro"] + } // 默认使用Sonnet价格 return s.fallbackPrices["claude-sonnet-4"] @@ -172,12 +206,20 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) { if s.pricingService != nil { litellmPricing := s.pricingService.GetModelPricing(model) if litellmPricing != nil { + // 启用 5m/1h 分类计费的条件: + // 1. 存在 1h 价格 + // 2. 
1h 价格 > 5m 价格(防止 LiteLLM 数据错误导致少收费) + price5m := litellmPricing.CacheCreationInputTokenCost + price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr + enableBreakdown := price1h > 0 && price1h > price5m return &ModelPricing{ InputPricePerToken: litellmPricing.InputCostPerToken, OutputPricePerToken: litellmPricing.OutputCostPerToken, CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost, CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost, - SupportsCacheBreakdown: false, + CacheCreation5mPrice: price5m, + CacheCreation1hPrice: price1h, + SupportsCacheBreakdown: enableBreakdown, }, nil } } @@ -209,9 +251,14 @@ func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMul // 计算缓存费用 if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) { - // 支持详细缓存分类的模型(5分钟/1小时缓存) - breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)/1_000_000*pricing.CacheCreation5mPrice + - float64(tokens.CacheCreation1hTokens)/1_000_000*pricing.CacheCreation1hPrice + // 支持详细缓存分类的模型(5分钟/1小时缓存,价格为 per-token) + if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 { + // API 未返回 ephemeral 明细,回退到全部按 5m 单价计费 + breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice + } else { + breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice + + float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice + } } else { // 标准缓存创建价格(per-token) breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken @@ -280,10 +327,12 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage // 范围内部分:正常计费 inRangeTokens := UsageTokens{ - InputTokens: inRangeInputTokens, - OutputTokens: tokens.OutputTokens, // 输出只算一次 - CacheCreationTokens: tokens.CacheCreationTokens, - CacheReadTokens: 
inRangeCacheTokens, + InputTokens: inRangeInputTokens, + OutputTokens: tokens.OutputTokens, // 输出只算一次 + CacheCreationTokens: tokens.CacheCreationTokens, + CacheReadTokens: inRangeCacheTokens, + CacheCreation5mTokens: tokens.CacheCreation5mTokens, + CacheCreation1hTokens: tokens.CacheCreation1hTokens, } inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier) if err != nil { @@ -297,7 +346,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage } outRangeCost, err := s.CalculateCost(model, outRangeTokens, rateMultiplier*extraMultiplier) if err != nil { - return inRangeCost, nil // 出错时返回范围内成本 + return inRangeCost, fmt.Errorf("out-range cost: %w", err) } // 合并成本 @@ -373,6 +422,14 @@ type ImagePriceConfig struct { Price4K *float64 // 4K 尺寸价格(nil 表示使用默认值) } +// SoraPriceConfig Sora 按次计费配置 +type SoraPriceConfig struct { + ImagePrice360 *float64 + ImagePrice540 *float64 + VideoPricePerRequest *float64 + VideoPricePerRequestHD *float64 +} + // CalculateImageCost 计算图片生成费用 // model: 请求的模型名称(用于获取 LiteLLM 默认价格) // imageSize: 图片尺寸 "1K", "2K", "4K" @@ -402,6 +459,65 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag } } +// CalculateSoraImageCost 计算 Sora 图片按次费用 +func (s *BillingService) CalculateSoraImageCost(imageSize string, imageCount int, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown { + if imageCount <= 0 { + return &CostBreakdown{} + } + + unitPrice := 0.0 + if groupConfig != nil { + switch imageSize { + case "540": + if groupConfig.ImagePrice540 != nil { + unitPrice = *groupConfig.ImagePrice540 + } + default: + if groupConfig.ImagePrice360 != nil { + unitPrice = *groupConfig.ImagePrice360 + } + } + } + + totalCost := unitPrice * float64(imageCount) + if rateMultiplier <= 0 { + rateMultiplier = 1.0 + } + actualCost := totalCost * rateMultiplier + + return &CostBreakdown{ + TotalCost: totalCost, + ActualCost: actualCost, + } +} + +// CalculateSoraVideoCost 计算 Sora 视频按次费用 
+func (s *BillingService) CalculateSoraVideoCost(model string, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown { + unitPrice := 0.0 + if groupConfig != nil { + modelLower := strings.ToLower(model) + if strings.Contains(modelLower, "sora2pro-hd") { + if groupConfig.VideoPricePerRequestHD != nil { + unitPrice = *groupConfig.VideoPricePerRequestHD + } + } + if unitPrice <= 0 && groupConfig.VideoPricePerRequest != nil { + unitPrice = *groupConfig.VideoPricePerRequest + } + } + + totalCost := unitPrice + if rateMultiplier <= 0 { + rateMultiplier = 1.0 + } + actualCost := totalCost * rateMultiplier + + return &CostBreakdown{ + TotalCost: totalCost, + ActualCost: actualCost, + } +} + // getImageUnitPrice 获取图片单价 func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 { // 优先使用分组配置的价格 @@ -443,7 +559,10 @@ func (s *BillingService) getDefaultImagePrice(model string, imageSize string) fl basePrice = 0.134 } - // 4K 尺寸翻倍 + // 2K 尺寸 1.5 倍,4K 尺寸翻倍 + if imageSize == "2K" { + return basePrice * 1.5 + } if imageSize == "4K" { return basePrice * 2 } diff --git a/backend/internal/service/billing_service_image_test.go b/backend/internal/service/billing_service_image_test.go index 18a6b74d..fa90f6bb 100644 --- a/backend/internal/service/billing_service_image_test.go +++ b/backend/internal/service/billing_service_image_test.go @@ -12,14 +12,14 @@ import ( func TestCalculateImageCost_DefaultPricing(t *testing.T) { svc := &BillingService{} // pricingService 为 nil,使用硬编码默认值 - // 2K 尺寸,默认价格 $0.134 + // 2K 尺寸,默认价格 $0.134 * 1.5 = $0.201 cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.0) - require.InDelta(t, 0.134, cost.TotalCost, 0.0001) - require.InDelta(t, 0.134, cost.ActualCost, 0.0001) + require.InDelta(t, 0.201, cost.TotalCost, 0.0001) + require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // 多张图片 cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 3, nil, 1.0) - require.InDelta(t, 
0.402, cost.TotalCost, 0.0001) + require.InDelta(t, 0.603, cost.TotalCost, 0.0001) } // TestCalculateImageCost_GroupCustomPricing 测试分组自定义价格 @@ -63,13 +63,13 @@ func TestCalculateImageCost_RateMultiplier(t *testing.T) { // 费率倍数 1.5x cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.5) - require.InDelta(t, 0.134, cost.TotalCost, 0.0001) // TotalCost 不变 - require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // ActualCost = 0.134 * 1.5 + require.InDelta(t, 0.201, cost.TotalCost, 0.0001) // TotalCost = 0.134 * 1.5 + require.InDelta(t, 0.3015, cost.ActualCost, 0.0001) // ActualCost = 0.201 * 1.5 // 费率倍数 2.0x cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 2, nil, 2.0) - require.InDelta(t, 0.268, cost.TotalCost, 0.0001) - require.InDelta(t, 0.536, cost.ActualCost, 0.0001) + require.InDelta(t, 0.402, cost.TotalCost, 0.0001) + require.InDelta(t, 0.804, cost.ActualCost, 0.0001) } // TestCalculateImageCost_ZeroCount 测试 imageCount=0 @@ -95,8 +95,8 @@ func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) { svc := &BillingService{} cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0) - require.InDelta(t, 0.134, cost.TotalCost, 0.0001) - require.InDelta(t, 0.134, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理 + require.InDelta(t, 0.201, cost.TotalCost, 0.0001) + require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理 } // TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格 @@ -127,9 +127,9 @@ func TestGetImageUnitPrice_PartialGroupConfig(t *testing.T) { cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 1, groupConfig, 1.0) require.InDelta(t, 0.10, cost.TotalCost, 0.0001) - // 2K 回退默认价格 $0.134 + // 2K 回退默认价格 $0.201 (1.5倍) cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, groupConfig, 1.0) - require.InDelta(t, 0.134, cost.TotalCost, 0.0001) + require.InDelta(t, 0.201, cost.TotalCost, 0.0001) // 4K 回退默认价格 $0.268 (翻倍) cost = svc.CalculateImageCost("gemini-3-pro-image", "4K", 1, groupConfig, 1.0) @@ 
-140,10 +140,10 @@ func TestGetImageUnitPrice_PartialGroupConfig(t *testing.T) { func TestGetDefaultImagePrice_FallbackHardcoded(t *testing.T) { svc := &BillingService{} // pricingService 为 nil - // 1K 和 2K 使用相同的默认价格 $0.134 + // 1K 默认价格 $0.134,2K 默认价格 $0.201 (1.5倍) cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 1, nil, 1.0) require.InDelta(t, 0.134, cost.TotalCost, 0.0001) cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.0) - require.InDelta(t, 0.134, cost.TotalCost, 0.0001) + require.InDelta(t, 0.201, cost.TotalCost, 0.0001) } diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go new file mode 100644 index 00000000..5eb278f6 --- /dev/null +++ b/backend/internal/service/billing_service_test.go @@ -0,0 +1,437 @@ +//go:build unit + +package service + +import ( + "math" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func newTestBillingService() *BillingService { + return NewBillingService(&config.Config{}, nil) +} + +func TestCalculateCost_BasicComputation(t *testing.T) { + svc := newTestBillingService() + + // 使用 claude-sonnet-4 的回退价格:Input $3/MTok, Output $15/MTok + tokens := UsageTokens{ + InputTokens: 1000, + OutputTokens: 500, + } + cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + // 1000 * 3e-6 = 0.003, 500 * 15e-6 = 0.0075 + expectedInput := 1000 * 3e-6 + expectedOutput := 500 * 15e-6 + require.InDelta(t, expectedInput, cost.InputCost, 1e-10) + require.InDelta(t, expectedOutput, cost.OutputCost, 1e-10) + require.InDelta(t, expectedInput+expectedOutput, cost.TotalCost, 1e-10) + require.InDelta(t, expectedInput+expectedOutput, cost.ActualCost, 1e-10) +} + +func TestCalculateCost_WithCacheTokens(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{ + InputTokens: 1000, + OutputTokens: 500, + CacheCreationTokens: 2000, + CacheReadTokens: 3000, + } + 
cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + expectedCacheCreation := 2000 * 3.75e-6 + expectedCacheRead := 3000 * 0.3e-6 + require.InDelta(t, expectedCacheCreation, cost.CacheCreationCost, 1e-10) + require.InDelta(t, expectedCacheRead, cost.CacheReadCost, 1e-10) + + expectedTotal := cost.InputCost + cost.OutputCost + expectedCacheCreation + expectedCacheRead + require.InDelta(t, expectedTotal, cost.TotalCost, 1e-10) +} + +func TestCalculateCost_RateMultiplier(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} + + cost1x, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + cost2x, err := svc.CalculateCost("claude-sonnet-4", tokens, 2.0) + require.NoError(t, err) + + // TotalCost 不受倍率影响,ActualCost 翻倍 + require.InDelta(t, cost1x.TotalCost, cost2x.TotalCost, 1e-10) + require.InDelta(t, cost1x.ActualCost*2, cost2x.ActualCost, 1e-10) +} + +func TestCalculateCost_ZeroMultiplierDefaultsToOne(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 1000} + + costZero, err := svc.CalculateCost("claude-sonnet-4", tokens, 0) + require.NoError(t, err) + + costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10) +} + +func TestCalculateCost_NegativeMultiplierDefaultsToOne(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 1000} + + costNeg, err := svc.CalculateCost("claude-sonnet-4", tokens, -1.0) + require.NoError(t, err) + + costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10) +} + +func TestGetModelPricing_FallbackMatchesByFamily(t *testing.T) { + svc := newTestBillingService() + + tests := []struct { + model string + expectedInput float64 + }{ + 
{"claude-opus-4.5-20250101", 5e-6}, + {"claude-3-opus-20240229", 15e-6}, + {"claude-sonnet-4-20250514", 3e-6}, + {"claude-3-5-sonnet-20241022", 3e-6}, + {"claude-3-5-haiku-20241022", 1e-6}, + {"claude-3-haiku-20240307", 0.25e-6}, + } + + for _, tt := range tests { + pricing, err := svc.GetModelPricing(tt.model) + require.NoError(t, err, "模型 %s", tt.model) + require.InDelta(t, tt.expectedInput, pricing.InputPricePerToken, 1e-12, "模型 %s 输入价格", tt.model) + } +} + +func TestGetModelPricing_CaseInsensitive(t *testing.T) { + svc := newTestBillingService() + + p1, err := svc.GetModelPricing("Claude-Sonnet-4") + require.NoError(t, err) + + p2, err := svc.GetModelPricing("claude-sonnet-4") + require.NoError(t, err) + + require.Equal(t, p1.InputPricePerToken, p2.InputPricePerToken) +} + +func TestGetModelPricing_UnknownModelFallsBackToSonnet(t *testing.T) { + svc := newTestBillingService() + + // 不包含 opus/sonnet/haiku 关键词的 Claude 模型会走默认 Sonnet 价格 + pricing, err := svc.GetModelPricing("claude-unknown-model") + require.NoError(t, err) + require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12) +} + +func TestCalculateCostWithLongContext_BelowThreshold(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{ + InputTokens: 50000, + OutputTokens: 1000, + CacheReadTokens: 100000, + } + // 总输入 150k < 200k 阈值,应走正常计费 + cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 2.0) + require.NoError(t, err) + + normalCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, normalCost.ActualCost, cost.ActualCost, 1e-10) +} + +func TestCalculateCostWithLongContext_AboveThreshold_CacheExceedsThreshold(t *testing.T) { + svc := newTestBillingService() + + // 缓存 210k + 输入 10k = 220k > 200k 阈值 + // 缓存已超阈值:范围内 200k 缓存,范围外 10k 缓存 + 10k 输入 + tokens := UsageTokens{ + InputTokens: 10000, + OutputTokens: 1000, + CacheReadTokens: 210000, + } + cost, err := 
svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 2.0) + require.NoError(t, err) + + // 范围内:200k cache + 0 input + 1k output + inRange, _ := svc.CalculateCost("claude-sonnet-4", UsageTokens{ + InputTokens: 0, + OutputTokens: 1000, + CacheReadTokens: 200000, + }, 1.0) + + // 范围外:10k cache + 10k input,倍率 2.0 + outRange, _ := svc.CalculateCost("claude-sonnet-4", UsageTokens{ + InputTokens: 10000, + CacheReadTokens: 10000, + }, 2.0) + + require.InDelta(t, inRange.ActualCost+outRange.ActualCost, cost.ActualCost, 1e-10) +} + +func TestCalculateCostWithLongContext_AboveThreshold_CacheBelowThreshold(t *testing.T) { + svc := newTestBillingService() + + // 缓存 100k + 输入 150k = 250k > 200k 阈值 + // 缓存未超阈值:范围内 100k 缓存 + 100k 输入,范围外 50k 输入 + tokens := UsageTokens{ + InputTokens: 150000, + OutputTokens: 1000, + CacheReadTokens: 100000, + } + cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 2.0) + require.NoError(t, err) + + require.True(t, cost.ActualCost > 0, "费用应大于 0") + + // 正常费用不含长上下文 + normalCost, _ := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.True(t, cost.ActualCost > normalCost.ActualCost, "长上下文费用应高于正常费用") +} + +func TestCalculateCostWithLongContext_DisabledThreshold(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 300000, CacheReadTokens: 0} + + // threshold <= 0 应禁用长上下文计费 + cost1, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 0, 2.0) + require.NoError(t, err) + + cost2, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, cost2.ActualCost, cost1.ActualCost, 1e-10) +} + +func TestCalculateCostWithLongContext_ExtraMultiplierLessEqualOne(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 300000} + + // extraMultiplier <= 1 应禁用长上下文计费 + cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 1.0) + require.NoError(t, 
err) + + normalCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, normalCost.ActualCost, cost.ActualCost, 1e-10) +} + +func TestCalculateImageCost(t *testing.T) { + svc := newTestBillingService() + + price := 0.134 + cfg := &ImagePriceConfig{Price1K: &price} + cost := svc.CalculateImageCost("gpt-image-1", "1K", 3, cfg, 1.0) + + require.InDelta(t, 0.134*3, cost.TotalCost, 1e-10) + require.InDelta(t, 0.134*3, cost.ActualCost, 1e-10) +} + +func TestCalculateSoraVideoCost(t *testing.T) { + svc := newTestBillingService() + + price := 0.5 + cfg := &SoraPriceConfig{VideoPricePerRequest: &price} + cost := svc.CalculateSoraVideoCost("sora-video", cfg, 1.0) + + require.InDelta(t, 0.5, cost.TotalCost, 1e-10) +} + +func TestCalculateSoraVideoCost_HDModel(t *testing.T) { + svc := newTestBillingService() + + hdPrice := 1.0 + normalPrice := 0.5 + cfg := &SoraPriceConfig{ + VideoPricePerRequest: &normalPrice, + VideoPricePerRequestHD: &hdPrice, + } + cost := svc.CalculateSoraVideoCost("sora2pro-hd", cfg, 1.0) + require.InDelta(t, 1.0, cost.TotalCost, 1e-10) +} + +func TestIsModelSupported(t *testing.T) { + svc := newTestBillingService() + + require.True(t, svc.IsModelSupported("claude-sonnet-4")) + require.True(t, svc.IsModelSupported("Claude-Opus-4.5")) + require.True(t, svc.IsModelSupported("claude-3-haiku")) + require.False(t, svc.IsModelSupported("gpt-4o")) + require.False(t, svc.IsModelSupported("gemini-pro")) +} + +func TestCalculateCost_ZeroTokens(t *testing.T) { + svc := newTestBillingService() + + cost, err := svc.CalculateCost("claude-sonnet-4", UsageTokens{}, 1.0) + require.NoError(t, err) + require.Equal(t, 0.0, cost.TotalCost) + require.Equal(t, 0.0, cost.ActualCost) +} + +func TestCalculateCostWithConfig(t *testing.T) { + cfg := &config.Config{} + cfg.Default.RateMultiplier = 1.5 + svc := NewBillingService(cfg, nil) + + tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} + cost, err := 
svc.CalculateCostWithConfig("claude-sonnet-4", tokens) + require.NoError(t, err) + + expected, _ := svc.CalculateCost("claude-sonnet-4", tokens, 1.5) + require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-10) +} + +func TestCalculateCostWithConfig_ZeroMultiplier(t *testing.T) { + cfg := &config.Config{} + cfg.Default.RateMultiplier = 0 + svc := NewBillingService(cfg, nil) + + tokens := UsageTokens{InputTokens: 1000} + cost, err := svc.CalculateCostWithConfig("claude-sonnet-4", tokens) + require.NoError(t, err) + + // 倍率 <=0 时默认 1.0 + expected, _ := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-10) +} + +func TestGetEstimatedCost(t *testing.T) { + svc := newTestBillingService() + + est, err := svc.GetEstimatedCost("claude-sonnet-4", 1000, 500) + require.NoError(t, err) + require.True(t, est > 0) +} + +func TestListSupportedModels(t *testing.T) { + svc := newTestBillingService() + + models := svc.ListSupportedModels() + require.NotEmpty(t, models) + require.GreaterOrEqual(t, len(models), 6) +} + +func TestGetPricingServiceStatus_NilService(t *testing.T) { + svc := newTestBillingService() + + status := svc.GetPricingServiceStatus() + require.NotNil(t, status) + require.Equal(t, "using fallback", status["last_updated"]) +} + +func TestForceUpdatePricing_NilService(t *testing.T) { + svc := newTestBillingService() + + err := svc.ForceUpdatePricing() + require.Error(t, err) + require.Contains(t, err.Error(), "not initialized") +} + +func TestCalculateSoraImageCost(t *testing.T) { + svc := newTestBillingService() + + price360 := 0.05 + price540 := 0.08 + cfg := &SoraPriceConfig{ImagePrice360: &price360, ImagePrice540: &price540} + + cost := svc.CalculateSoraImageCost("360", 2, cfg, 1.0) + require.InDelta(t, 0.10, cost.TotalCost, 1e-10) + + cost540 := svc.CalculateSoraImageCost("540", 1, cfg, 2.0) + require.InDelta(t, 0.08, cost540.TotalCost, 1e-10) + require.InDelta(t, 0.16, cost540.ActualCost, 1e-10) 
+} + +func TestCalculateSoraImageCost_ZeroCount(t *testing.T) { + svc := newTestBillingService() + cost := svc.CalculateSoraImageCost("360", 0, nil, 1.0) + require.Equal(t, 0.0, cost.TotalCost) +} + +func TestCalculateSoraVideoCost_NilConfig(t *testing.T) { + svc := newTestBillingService() + cost := svc.CalculateSoraVideoCost("sora-video", nil, 1.0) + require.Equal(t, 0.0, cost.TotalCost) +} + +func TestCalculateCostWithLongContext_PropagatesError(t *testing.T) { + // 使用空的 fallback prices 让 GetModelPricing 失败 + svc := &BillingService{ + cfg: &config.Config{}, + fallbackPrices: make(map[string]*ModelPricing), + } + + tokens := UsageTokens{InputTokens: 300000, CacheReadTokens: 0} + _, err := svc.CalculateCostWithLongContext("unknown-model", tokens, 1.0, 200000, 2.0) + require.Error(t, err) + require.Contains(t, err.Error(), "pricing not found") +} + +func TestCalculateCost_SupportsCacheBreakdown(t *testing.T) { + svc := &BillingService{ + cfg: &config.Config{}, + fallbackPrices: map[string]*ModelPricing{ + "claude-sonnet-4": { + InputPricePerToken: 3e-6, + OutputPricePerToken: 15e-6, + SupportsCacheBreakdown: true, + CacheCreation5mPrice: 4e-6, // per token + CacheCreation1hPrice: 5e-6, // per token + }, + }, + } + + tokens := UsageTokens{ + InputTokens: 1000, + OutputTokens: 500, + CacheCreation5mTokens: 100000, + CacheCreation1hTokens: 50000, + } + cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + expected5m := float64(tokens.CacheCreation5mTokens) * 4e-6 + expected1h := float64(tokens.CacheCreation1hTokens) * 5e-6 + require.InDelta(t, expected5m+expected1h, cost.CacheCreationCost, 1e-10) +} + +func TestCalculateCost_LargeTokenCount(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{ + InputTokens: 1_000_000, + OutputTokens: 1_000_000, + } + cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + // Input: 1M * 3e-6 = $3, Output: 1M * 15e-6 = $15 + 
require.InDelta(t, 3.0, cost.InputCost, 1e-6) + require.InDelta(t, 15.0, cost.OutputCost, 1e-6) + require.False(t, math.IsNaN(cost.TotalCost)) + require.False(t, math.IsInf(cost.TotalCost, 0)) +} diff --git a/backend/internal/service/claude_code_detection_test.go b/backend/internal/service/claude_code_detection_test.go new file mode 100644 index 00000000..ff7ad7f4 --- /dev/null +++ b/backend/internal/service/claude_code_detection_test.go @@ -0,0 +1,282 @@ +//go:build unit + +package service + +import ( + "context" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +func newTestValidator() *ClaudeCodeValidator { + return NewClaudeCodeValidator() +} + +// validClaudeCodeBody 构造一个完整有效的 Claude Code 请求体 +func validClaudeCodeBody() map[string]any { + return map[string]any{ + "model": "claude-sonnet-4-20250514", + "system": []any{ + map[string]any{ + "type": "text", + "text": "You are Claude Code, Anthropic's official CLI for Claude.", + }, + }, + "metadata": map[string]any{ + "user_id": "user_" + "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + "_account__session_" + "12345678-1234-1234-1234-123456789abc", + }, + } +} + +func TestValidate_ClaudeCLIUserAgent(t *testing.T) { + v := newTestValidator() + + tests := []struct { + name string + ua string + want bool + }{ + {"标准版本号", "claude-cli/1.0.0", true}, + {"多位版本号", "claude-cli/12.34.56", true}, + {"大写开头", "Claude-CLI/1.0.0", true}, + {"非 claude-cli", "curl/7.64.1", false}, + {"空 User-Agent", "", false}, + {"部分匹配", "not-claude-cli/1.0.0", false}, + {"缺少版本号", "claude-cli/", false}, + {"版本格式不对", "claude-cli/1.0", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, v.ValidateUserAgent(tt.ua), "UA: %q", tt.ua) + }) + } +} + +func TestValidate_NonMessagesPath_UAOnly(t *testing.T) { + v := newTestValidator() + + // 非 messages 路径只检查 UA + req := httptest.NewRequest("GET", 
"/v1/models", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + + result := v.Validate(req, nil) + require.True(t, result, "非 messages 路径只需 UA 匹配") +} + +func TestValidate_NonMessagesPath_InvalidUA(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("GET", "/v1/models", nil) + req.Header.Set("User-Agent", "curl/7.64.1") + + result := v.Validate(req, nil) + require.False(t, result, "UA 不匹配时应返回 false") +} + +func TestValidate_MessagesPath_FullValid(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15") + req.Header.Set("anthropic-version", "2023-06-01") + + result := v.Validate(req, validClaudeCodeBody()) + require.True(t, result, "完整有效请求应通过") +} + +func TestValidate_MessagesPath_MissingHeaders(t *testing.T) { + v := newTestValidator() + body := validClaudeCodeBody() + + tests := []struct { + name string + missingHeader string + }{ + {"缺少 X-App", "X-App"}, + {"缺少 anthropic-beta", "anthropic-beta"}, + {"缺少 anthropic-version", "anthropic-version"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "beta") + req.Header.Set("anthropic-version", "2023-06-01") + req.Header.Del(tt.missingHeader) + + result := v.Validate(req, body) + require.False(t, result, "缺少 %s 应返回 false", tt.missingHeader) + }) + } +} + +func TestValidate_MessagesPath_InvalidMetadataUserID(t *testing.T) { + v := newTestValidator() + + tests := []struct { + name string + metadata map[string]any + }{ + {"缺少 metadata", nil}, + {"缺少 user_id", map[string]any{"other": "value"}}, + {"空 user_id", map[string]any{"user_id": ""}}, + {"格式错误", map[string]any{"user_id": 
"invalid-format"}}, + {"hex 长度不足", map[string]any{"user_id": "user_abc_account__session_uuid"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "beta") + req.Header.Set("anthropic-version", "2023-06-01") + + body := map[string]any{ + "model": "claude-sonnet-4", + "system": []any{ + map[string]any{ + "type": "text", + "text": "You are Claude Code, Anthropic's official CLI for Claude.", + }, + }, + } + if tt.metadata != nil { + body["metadata"] = tt.metadata + } + + result := v.Validate(req, body) + require.False(t, result, "metadata.user_id: %v", tt.metadata) + }) + } +} + +func TestValidate_MessagesPath_InvalidSystemPrompt(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "beta") + req.Header.Set("anthropic-version", "2023-06-01") + + body := map[string]any{ + "model": "claude-sonnet-4", + "system": []any{ + map[string]any{ + "type": "text", + "text": "Generate JSON data for testing database migrations.", + }, + }, + "metadata": map[string]any{ + "user_id": "user_" + "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + "_account__session_12345678-1234-1234-1234-123456789abc", + }, + } + + result := v.Validate(req, body) + require.False(t, result, "无关系统提示词应返回 false") +} + +func TestValidate_MaxTokensOneHaikuBypass(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + // 不设置 X-App 等头,通过 context 标记为 haiku 探测请求 + ctx := context.WithValue(req.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true) + req = req.WithContext(ctx) + + // 即使 body 不包含 system prompt,也应通过 + result := 
v.Validate(req, map[string]any{"model": "claude-3-haiku", "max_tokens": 1}) + require.True(t, result, "max_tokens=1+haiku 探测请求应绕过严格验证") +} + +func TestSystemPromptSimilarity(t *testing.T) { + v := newTestValidator() + + tests := []struct { + name string + prompt string + want bool + }{ + {"精确匹配", "You are Claude Code, Anthropic's official CLI for Claude.", true}, + {"带多余空格", "You are Claude Code, Anthropic's official CLI for Claude.", true}, + {"Agent SDK 模板", "You are a Claude agent, built on Anthropic's Claude Agent SDK.", true}, + {"文件搜索专家模板", "You are a file search specialist for Claude Code, Anthropic's official CLI for Claude.", true}, + {"对话摘要模板", "You are a helpful AI assistant tasked with summarizing conversations.", true}, + {"交互式 CLI 模板", "You are an interactive CLI tool that helps users", true}, + {"无关文本", "Write me a poem about cats", false}, + {"空文本", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body := map[string]any{ + "model": "claude-sonnet-4", + "system": []any{ + map[string]any{"type": "text", "text": tt.prompt}, + }, + } + result := v.IncludesClaudeCodeSystemPrompt(body) + require.Equal(t, tt.want, result, "提示词: %q", tt.prompt) + }) + } +} + +func TestDiceCoefficient(t *testing.T) { + tests := []struct { + name string + a string + b string + want float64 + tol float64 + }{ + {"相同字符串", "hello", "hello", 1.0, 0.001}, + {"完全不同", "abc", "xyz", 0.0, 0.001}, + {"空字符串", "", "hello", 0.0, 0.001}, + {"单字符", "a", "b", 0.0, 0.001}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := diceCoefficient(tt.a, tt.b) + require.InDelta(t, tt.want, result, tt.tol) + }) + } +} + +func TestIsClaudeCodeClient_Context(t *testing.T) { + ctx := context.Background() + + // 默认应为 false + require.False(t, IsClaudeCodeClient(ctx)) + + // 设置为 true + ctx = SetClaudeCodeClient(ctx, true) + require.True(t, IsClaudeCodeClient(ctx)) + + // 设置为 false + ctx = SetClaudeCodeClient(ctx, false) + require.False(t, 
IsClaudeCodeClient(ctx)) +} + +func TestValidate_NilBody_MessagesPath(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "beta") + req.Header.Set("anthropic-version", "2023-06-01") + + result := v.Validate(req, nil) + require.False(t, result, "nil body 的 messages 请求应返回 false") +} diff --git a/backend/internal/service/claude_code_validator.go b/backend/internal/service/claude_code_validator.go index 6d06c83e..f71098b1 100644 --- a/backend/internal/service/claude_code_validator.go +++ b/backend/internal/service/claude_code_validator.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "regexp" + "strconv" "strings" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" @@ -17,6 +18,9 @@ var ( // User-Agent 匹配: claude-cli/x.x.x (仅支持官方 CLI,大小写不敏感) claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`) + // 带捕获组的版本提取正则 + claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`) + // metadata.user_id 格式: user_{64位hex}_account__session_{uuid} userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`) @@ -78,7 +82,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo // Step 3: 检查 max_tokens=1 + haiku 探测请求绕过 // 这类请求用于 Claude Code 验证 API 连通性,不携带 system prompt - if isMaxTokensOneHaiku, ok := r.Context().Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok && isMaxTokensOneHaiku { + if isMaxTokensOneHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(r.Context()); ok && isMaxTokensOneHaiku { return true // 绕过 system prompt 检查,UA 已在 Step 1 验证 } @@ -270,3 +274,55 @@ func IsClaudeCodeClient(ctx context.Context) bool { func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context { return context.WithValue(ctx, ctxkey.IsClaudeCodeClient, isClaudeCode) } + +// ExtractVersion 从 User-Agent 中提取 Claude 
Code 版本号 +// 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串 +func (v *ClaudeCodeValidator) ExtractVersion(ua string) string { + matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua) + if len(matches) >= 2 { + return matches[1] + } + return "" +} + +// SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中 +func SetClaudeCodeVersion(ctx context.Context, version string) context.Context { + return context.WithValue(ctx, ctxkey.ClaudeCodeVersion, version) +} + +// GetClaudeCodeVersion 从 context 中获取 Claude Code 版本号 +func GetClaudeCodeVersion(ctx context.Context) string { + if v, ok := ctx.Value(ctxkey.ClaudeCodeVersion).(string); ok { + return v + } + return "" +} + +// CompareVersions 比较两个 semver 版本号 +// 返回: -1 (a < b), 0 (a == b), 1 (a > b) +func CompareVersions(a, b string) int { + aParts := parseSemver(a) + bParts := parseSemver(b) + for i := 0; i < 3; i++ { + if aParts[i] < bParts[i] { + return -1 + } + if aParts[i] > bParts[i] { + return 1 + } + } + return 0 +} + +// parseSemver 解析 semver 版本号为 [major, minor, patch] +func parseSemver(v string) [3]int { + v = strings.TrimPrefix(v, "v") + parts := strings.Split(v, ".") + result := [3]int{0, 0, 0} + for i := 0; i < len(parts) && i < 3; i++ { + if parsed, err := strconv.Atoi(parts[i]); err == nil { + result[i] = parsed + } + } + return result +} diff --git a/backend/internal/service/claude_code_validator_test.go b/backend/internal/service/claude_code_validator_test.go index a4cd1886..f87c56e8 100644 --- a/backend/internal/service/claude_code_validator_test.go +++ b/backend/internal/service/claude_code_validator_test.go @@ -56,3 +56,51 @@ func TestClaudeCodeValidator_NonMessagesPathUAOnly(t *testing.T) { ok := validator.Validate(req, nil) require.True(t, ok) } + +func TestExtractVersion(t *testing.T) { + v := NewClaudeCodeValidator() + tests := []struct { + ua string + want string + }{ + {"claude-cli/2.1.22 (darwin; arm64)", "2.1.22"}, + {"claude-cli/1.0.0", "1.0.0"}, + {"Claude-CLI/3.10.5 (linux; x86_64)", "3.10.5"}, // 大小写不敏感 + 
{"curl/8.0.0", ""}, // 非 Claude CLI + {"", ""}, // 空字符串 + {"claude-cli/", ""}, // 无版本号 + {"claude-cli/2.1.22-beta", "2.1.22"}, // 带后缀仍提取主版本号 + } + for _, tt := range tests { + got := v.ExtractVersion(tt.ua) + require.Equal(t, tt.want, got, "ExtractVersion(%q)", tt.ua) + } +} + +func TestCompareVersions(t *testing.T) { + tests := []struct { + a, b string + want int + }{ + {"2.1.0", "2.1.0", 0}, // 相等 + {"2.1.1", "2.1.0", 1}, // patch 更大 + {"2.0.0", "2.1.0", -1}, // minor 更小 + {"3.0.0", "2.99.99", 1}, // major 更大 + {"1.0.0", "2.0.0", -1}, // major 更小 + {"0.0.1", "0.0.0", 1}, // patch 差异 + {"", "1.0.0", -1}, // 空字符串 vs 正常版本 + {"v2.1.0", "2.1.0", 0}, // v 前缀处理 + } + for _, tt := range tests { + got := CompareVersions(tt.a, tt.b) + require.Equal(t, tt.want, got, "CompareVersions(%q, %q)", tt.a, tt.b) + } +} + +func TestSetGetClaudeCodeVersion(t *testing.T) { + ctx := context.Background() + require.Equal(t, "", GetClaudeCodeVersion(ctx), "empty context should return empty string") + + ctx = SetClaudeCodeVersion(ctx, "2.1.63") + require.Equal(t, "2.1.63", GetClaudeCodeVersion(ctx)) +} diff --git a/backend/internal/service/concurrency_service.go b/backend/internal/service/concurrency_service.go index d5cb2025..4dcf84e0 100644 --- a/backend/internal/service/concurrency_service.go +++ b/backend/internal/service/concurrency_service.go @@ -3,10 +3,13 @@ package service import ( "context" "crypto/rand" - "encoding/hex" - "fmt" - "log" + "encoding/binary" + "os" + "strconv" + "sync/atomic" "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) // ConcurrencyCache 定义并发控制的缓存接口 @@ -17,6 +20,7 @@ type ConcurrencyCache interface { AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) + GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, 
error) // 账号等待队列(账号级) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) @@ -41,15 +45,25 @@ type ConcurrencyCache interface { CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error } -// generateRequestID generates a unique request ID for concurrency slot tracking -// Uses 8 random bytes (16 hex chars) for uniqueness -func generateRequestID() string { +var ( + requestIDPrefix = initRequestIDPrefix() + requestIDCounter atomic.Uint64 +) + +func initRequestIDPrefix() string { b := make([]byte, 8) - if _, err := rand.Read(b); err != nil { - // Fallback to nanosecond timestamp (extremely rare case) - return fmt.Sprintf("%x", time.Now().UnixNano()) + if _, err := rand.Read(b); err == nil { + return "r" + strconv.FormatUint(binary.BigEndian.Uint64(b), 36) } - return hex.EncodeToString(b) + fallback := uint64(time.Now().UnixNano()) ^ (uint64(os.Getpid()) << 16) + return "r" + strconv.FormatUint(fallback, 36) +} + +// generateRequestID generates a unique request ID for concurrency slot tracking. 
+// Format: {process_random_prefix}-{base36_counter} +func generateRequestID() string { + seq := requestIDCounter.Add(1) + return requestIDPrefix + "-" + strconv.FormatUint(seq, 36) } const ( @@ -124,7 +138,7 @@ func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID i bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := s.cache.ReleaseAccountSlot(bgCtx, accountID, requestID); err != nil { - log.Printf("Warning: failed to release account slot for %d (req=%s): %v", accountID, requestID, err) + logger.LegacyPrintf("service.concurrency", "Warning: failed to release account slot for %d (req=%s): %v", accountID, requestID, err) } }, }, nil @@ -163,7 +177,7 @@ func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := s.cache.ReleaseUserSlot(bgCtx, userID, requestID); err != nil { - log.Printf("Warning: failed to release user slot for %d (req=%s): %v", userID, requestID, err) + logger.LegacyPrintf("service.concurrency", "Warning: failed to release user slot for %d (req=%s): %v", userID, requestID, err) } }, }, nil @@ -191,7 +205,7 @@ func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int6 result, err := s.cache.IncrementWaitCount(ctx, userID, maxWait) if err != nil { // On error, allow the request to proceed (fail open) - log.Printf("Warning: increment wait count failed for user %d: %v", userID, err) + logger.LegacyPrintf("service.concurrency", "Warning: increment wait count failed for user %d: %v", userID, err) return true, nil } return result, nil @@ -209,7 +223,7 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6 defer cancel() if err := s.cache.DecrementWaitCount(bgCtx, userID); err != nil { - log.Printf("Warning: decrement wait count failed for user %d: %v", userID, err) + logger.LegacyPrintf("service.concurrency", 
"Warning: decrement wait count failed for user %d: %v", userID, err) } } @@ -221,7 +235,7 @@ func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, acco result, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait) if err != nil { - log.Printf("Warning: increment wait count failed for account %d: %v", accountID, err) + logger.LegacyPrintf("service.concurrency", "Warning: increment wait count failed for account %d: %v", accountID, err) return true, nil } return result, nil @@ -237,7 +251,7 @@ func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, acco defer cancel() if err := s.cache.DecrementAccountWaitCount(bgCtx, accountID); err != nil { - log.Printf("Warning: decrement wait count failed for account %d: %v", accountID, err) + logger.LegacyPrintf("service.concurrency", "Warning: decrement wait count failed for account %d: %v", accountID, err) } } @@ -293,7 +307,7 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor accounts, err := accountRepo.ListSchedulable(listCtx) cancel() if err != nil { - log.Printf("Warning: list schedulable accounts failed: %v", err) + logger.LegacyPrintf("service.concurrency", "Warning: list schedulable accounts failed: %v", err) return } for _, account := range accounts { @@ -301,7 +315,7 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID) accountCancel() if err != nil { - log.Printf("Warning: cleanup expired slots failed for account %d: %v", account.ID, err) + logger.LegacyPrintf("service.concurrency", "Warning: cleanup expired slots failed for account %d: %v", account.ID, err) } } } @@ -320,16 +334,15 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor // GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts // Returns a map of accountID -> current concurrency count func (s *ConcurrencyService) 
GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { - result := make(map[int64]int) - - for _, accountID := range accountIDs { - count, err := s.cache.GetAccountConcurrency(ctx, accountID) - if err != nil { - // If key doesn't exist in Redis, count is 0 - count = 0 - } - result[accountID] = count + if len(accountIDs) == 0 { + return map[int64]int{}, nil } - - return result, nil + if s.cache == nil { + result := make(map[int64]int, len(accountIDs)) + for _, accountID := range accountIDs { + result[accountID] = 0 + } + return result, nil + } + return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs) } diff --git a/backend/internal/service/concurrency_service_test.go b/backend/internal/service/concurrency_service_test.go new file mode 100644 index 00000000..9ba43d93 --- /dev/null +++ b/backend/internal/service/concurrency_service_test.go @@ -0,0 +1,311 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +// stubConcurrencyCacheForTest 用于并发服务单元测试的缓存桩 +type stubConcurrencyCacheForTest struct { + acquireResult bool + acquireErr error + releaseErr error + concurrency int + concurrencyErr error + waitAllowed bool + waitErr error + waitCount int + waitCountErr error + loadBatch map[int64]*AccountLoadInfo + loadBatchErr error + usersLoadBatch map[int64]*UserLoadInfo + usersLoadErr error + cleanupErr error + + // 记录调用 + releasedAccountIDs []int64 + releasedRequestIDs []string +} + +var _ ConcurrencyCache = (*stubConcurrencyCacheForTest)(nil) + +func (c *stubConcurrencyCacheForTest) AcquireAccountSlot(_ context.Context, _ int64, _ int, _ string) (bool, error) { + return c.acquireResult, c.acquireErr +} +func (c *stubConcurrencyCacheForTest) ReleaseAccountSlot(_ context.Context, accountID int64, requestID string) error { + c.releasedAccountIDs = append(c.releasedAccountIDs, accountID) + c.releasedRequestIDs = 
append(c.releasedRequestIDs, requestID) + return c.releaseErr +} +func (c *stubConcurrencyCacheForTest) GetAccountConcurrency(_ context.Context, _ int64) (int, error) { + return c.concurrency, c.concurrencyErr +} +func (c *stubConcurrencyCacheForTest) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) { + result := make(map[int64]int, len(accountIDs)) + for _, accountID := range accountIDs { + if c.concurrencyErr != nil { + return nil, c.concurrencyErr + } + result[accountID] = c.concurrency + } + return result, nil +} +func (c *stubConcurrencyCacheForTest) IncrementAccountWaitCount(_ context.Context, _ int64, _ int) (bool, error) { + return c.waitAllowed, c.waitErr +} +func (c *stubConcurrencyCacheForTest) DecrementAccountWaitCount(_ context.Context, _ int64) error { + return nil +} +func (c *stubConcurrencyCacheForTest) GetAccountWaitingCount(_ context.Context, _ int64) (int, error) { + return c.waitCount, c.waitCountErr +} +func (c *stubConcurrencyCacheForTest) AcquireUserSlot(_ context.Context, _ int64, _ int, _ string) (bool, error) { + return c.acquireResult, c.acquireErr +} +func (c *stubConcurrencyCacheForTest) ReleaseUserSlot(_ context.Context, _ int64, _ string) error { + return c.releaseErr +} +func (c *stubConcurrencyCacheForTest) GetUserConcurrency(_ context.Context, _ int64) (int, error) { + return c.concurrency, c.concurrencyErr +} +func (c *stubConcurrencyCacheForTest) IncrementWaitCount(_ context.Context, _ int64, _ int) (bool, error) { + return c.waitAllowed, c.waitErr +} +func (c *stubConcurrencyCacheForTest) DecrementWaitCount(_ context.Context, _ int64) error { + return nil +} +func (c *stubConcurrencyCacheForTest) GetAccountsLoadBatch(_ context.Context, _ []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { + return c.loadBatch, c.loadBatchErr +} +func (c *stubConcurrencyCacheForTest) GetUsersLoadBatch(_ context.Context, _ []UserWithConcurrency) (map[int64]*UserLoadInfo, error) { + return 
c.usersLoadBatch, c.usersLoadErr +} +func (c *stubConcurrencyCacheForTest) CleanupExpiredAccountSlots(_ context.Context, _ int64) error { + return c.cleanupErr +} + +func TestAcquireAccountSlot_Success(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireResult: true} + svc := NewConcurrencyService(cache) + + result, err := svc.AcquireAccountSlot(context.Background(), 1, 5) + require.NoError(t, err) + require.True(t, result.Acquired) + require.NotNil(t, result.ReleaseFunc) +} + +func TestAcquireAccountSlot_Failure(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireResult: false} + svc := NewConcurrencyService(cache) + + result, err := svc.AcquireAccountSlot(context.Background(), 1, 5) + require.NoError(t, err) + require.False(t, result.Acquired) + require.Nil(t, result.ReleaseFunc) +} + +func TestAcquireAccountSlot_UnlimitedConcurrency(t *testing.T) { + svc := NewConcurrencyService(&stubConcurrencyCacheForTest{}) + + for _, maxConcurrency := range []int{0, -1} { + result, err := svc.AcquireAccountSlot(context.Background(), 1, maxConcurrency) + require.NoError(t, err) + require.True(t, result.Acquired, "maxConcurrency=%d 应无限制通过", maxConcurrency) + require.NotNil(t, result.ReleaseFunc, "ReleaseFunc 应为 no-op 函数") + } +} + +func TestAcquireAccountSlot_CacheError(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireErr: errors.New("redis down")} + svc := NewConcurrencyService(cache) + + result, err := svc.AcquireAccountSlot(context.Background(), 1, 5) + require.Error(t, err) + require.Nil(t, result) +} + +func TestAcquireAccountSlot_ReleaseDecrements(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireResult: true} + svc := NewConcurrencyService(cache) + + result, err := svc.AcquireAccountSlot(context.Background(), 42, 5) + require.NoError(t, err) + require.True(t, result.Acquired) + + // 调用 ReleaseFunc 应释放槽位 + result.ReleaseFunc() + + require.Len(t, cache.releasedAccountIDs, 1) + require.Equal(t, int64(42), 
cache.releasedAccountIDs[0]) + require.Len(t, cache.releasedRequestIDs, 1) + require.NotEmpty(t, cache.releasedRequestIDs[0], "requestID 不应为空") +} + +func TestAcquireUserSlot_IndependentFromAccount(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireResult: true} + svc := NewConcurrencyService(cache) + + // 用户槽位获取应独立于账户槽位 + result, err := svc.AcquireUserSlot(context.Background(), 100, 3) + require.NoError(t, err) + require.True(t, result.Acquired) + require.NotNil(t, result.ReleaseFunc) +} + +func TestAcquireUserSlot_UnlimitedConcurrency(t *testing.T) { + svc := NewConcurrencyService(&stubConcurrencyCacheForTest{}) + + result, err := svc.AcquireUserSlot(context.Background(), 1, 0) + require.NoError(t, err) + require.True(t, result.Acquired) +} + +func TestGenerateRequestID_UsesStablePrefixAndMonotonicCounter(t *testing.T) { + id1 := generateRequestID() + id2 := generateRequestID() + require.NotEmpty(t, id1) + require.NotEmpty(t, id2) + + p1 := strings.Split(id1, "-") + p2 := strings.Split(id2, "-") + require.Len(t, p1, 2) + require.Len(t, p2, 2) + require.Equal(t, p1[0], p2[0], "同一进程前缀应保持一致") + + n1, err := strconv.ParseUint(p1[1], 36, 64) + require.NoError(t, err) + n2, err := strconv.ParseUint(p2[1], 36, 64) + require.NoError(t, err) + require.Equal(t, n1+1, n2, "计数器应单调递增") +} + +func TestGetAccountsLoadBatch_ReturnsCorrectData(t *testing.T) { + expected := map[int64]*AccountLoadInfo{ + 1: {AccountID: 1, CurrentConcurrency: 3, WaitingCount: 0, LoadRate: 60}, + 2: {AccountID: 2, CurrentConcurrency: 5, WaitingCount: 2, LoadRate: 100}, + } + cache := &stubConcurrencyCacheForTest{loadBatch: expected} + svc := NewConcurrencyService(cache) + + accounts := []AccountWithConcurrency{ + {ID: 1, MaxConcurrency: 5}, + {ID: 2, MaxConcurrency: 5}, + } + result, err := svc.GetAccountsLoadBatch(context.Background(), accounts) + require.NoError(t, err) + require.Equal(t, expected, result) +} + +func TestGetAccountsLoadBatch_NilCache(t *testing.T) { + svc := 
&ConcurrencyService{cache: nil} + + result, err := svc.GetAccountsLoadBatch(context.Background(), nil) + require.NoError(t, err) + require.Empty(t, result) +} + +func TestIncrementWaitCount_Success(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitAllowed: true} + svc := NewConcurrencyService(cache) + + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.True(t, allowed) +} + +func TestIncrementWaitCount_QueueFull(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitAllowed: false} + svc := NewConcurrencyService(cache) + + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.False(t, allowed) +} + +func TestIncrementWaitCount_FailOpen(t *testing.T) { + // Redis 错误时应 fail-open(允许请求通过) + cache := &stubConcurrencyCacheForTest{waitErr: errors.New("redis timeout")} + svc := NewConcurrencyService(cache) + + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err, "Redis 错误不应传播") + require.True(t, allowed, "Redis 错误时应 fail-open") +} + +func TestIncrementWaitCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.True(t, allowed, "nil cache 应 fail-open") +} + +func TestCalculateMaxWait(t *testing.T) { + tests := []struct { + concurrency int + expected int + }{ + {5, 25}, // 5 + 20 + {1, 21}, // 1 + 20 + {0, 21}, // min(1) + 20 + {-1, 21}, // min(1) + 20 + {10, 30}, // 10 + 20 + } + for _, tt := range tests { + result := CalculateMaxWait(tt.concurrency) + require.Equal(t, tt.expected, result, "CalculateMaxWait(%d)", tt.concurrency) + } +} + +func TestGetAccountWaitingCount(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitCount: 5} + svc := NewConcurrencyService(cache) + + count, err := svc.GetAccountWaitingCount(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, 5, 
count) +} + +func TestGetAccountWaitingCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + + count, err := svc.GetAccountWaitingCount(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, 0, count) +} + +func TestGetAccountConcurrencyBatch(t *testing.T) { + cache := &stubConcurrencyCacheForTest{concurrency: 3} + svc := NewConcurrencyService(cache) + + result, err := svc.GetAccountConcurrencyBatch(context.Background(), []int64{1, 2, 3}) + require.NoError(t, err) + require.Len(t, result, 3) + for _, id := range []int64{1, 2, 3} { + require.Equal(t, 3, result[id]) + } +} + +func TestIncrementAccountWaitCount_FailOpen(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitErr: errors.New("redis error")} + svc := NewConcurrencyService(cache) + + allowed, err := svc.IncrementAccountWaitCount(context.Background(), 1, 10) + require.NoError(t, err, "Redis 错误不应传播") + require.True(t, allowed, "Redis 错误时应 fail-open") +} + +func TestIncrementAccountWaitCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + + allowed, err := svc.IncrementAccountWaitCount(context.Background(), 1, 10) + require.NoError(t, err) + require.True(t, allowed) +} diff --git a/backend/internal/service/crs_sync_service.go b/backend/internal/service/crs_sync_service.go index 040b2357..6a916740 100644 --- a/backend/internal/service/crs_sync_service.go +++ b/backend/internal/service/crs_sync_service.go @@ -221,7 +221,7 @@ func (s *CRSSyncService) fetchCRSExport(ctx context.Context, baseURL, username, AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts, }) if err != nil { - client = &http.Client{Timeout: 20 * time.Second} + return nil, fmt.Errorf("create http client failed: %w", err) } adminToken, err := crsLogin(ctx, client, normalizedURL, username, password) diff --git a/backend/internal/service/dashboard_aggregation_service.go b/backend/internal/service/dashboard_aggregation_service.go index 10c68868..a67f8532 100644 --- 
a/backend/internal/service/dashboard_aggregation_service.go +++ b/backend/internal/service/dashboard_aggregation_service.go @@ -3,11 +3,12 @@ package service import ( "context" "errors" - "log" + "log/slog" "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) const ( @@ -65,7 +66,7 @@ func (s *DashboardAggregationService) Start() { return } if !s.cfg.Enabled { - log.Printf("[DashboardAggregation] 聚合作业已禁用") + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合作业已禁用") return } @@ -81,9 +82,9 @@ func (s *DashboardAggregationService) Start() { s.timingWheel.ScheduleRecurring("dashboard:aggregation", interval, func() { s.runScheduledAggregation() }) - log.Printf("[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)", interval, s.cfg.LookbackSeconds) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)", interval, s.cfg.LookbackSeconds) if !s.cfg.BackfillEnabled { - log.Printf("[DashboardAggregation] 回填已禁用,如需补齐保留窗口以外历史数据请手动回填") + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填已禁用,如需补齐保留窗口以外历史数据请手动回填") } } @@ -93,7 +94,7 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro return errors.New("聚合服务未初始化") } if !s.cfg.BackfillEnabled { - log.Printf("[DashboardAggregation] 回填被拒绝: backfill_enabled=false") + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填被拒绝: backfill_enabled=false") return ErrDashboardBackfillDisabled } if !end.After(start) { @@ -110,7 +111,7 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout) defer cancel() if err := s.backfillRange(ctx, start, end); err != nil { - log.Printf("[DashboardAggregation] 回填失败: %v", err) + logger.LegacyPrintf("service.dashboard_aggregation", 
"[DashboardAggregation] 回填失败: %v", err) } }() return nil @@ -141,12 +142,12 @@ func (s *DashboardAggregationService) TriggerRecomputeRange(start, end time.Time return } if !errors.Is(err, errDashboardAggregationRunning) { - log.Printf("[DashboardAggregation] 重新计算失败: %v", err) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 重新计算失败: %v", err) return } time.Sleep(5 * time.Second) } - log.Printf("[DashboardAggregation] 重新计算放弃: 聚合作业持续占用") + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 重新计算放弃: 聚合作业持续占用") }() return nil } @@ -162,7 +163,7 @@ func (s *DashboardAggregationService) recomputeRecentDays() { ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout) defer cancel() if err := s.backfillRange(ctx, start, now); err != nil { - log.Printf("[DashboardAggregation] 启动重算失败: %v", err) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 启动重算失败: %v", err) return } } @@ -177,7 +178,7 @@ func (s *DashboardAggregationService) recomputeRange(ctx context.Context, start, if err := s.repo.RecomputeRange(ctx, start, end); err != nil { return err } - log.Printf("[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)", + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)", start.UTC().Format(time.RFC3339), end.UTC().Format(time.RFC3339), time.Since(jobStart).String(), @@ -198,7 +199,7 @@ func (s *DashboardAggregationService) runScheduledAggregation() { now := time.Now().UTC() last, err := s.repo.GetAggregationWatermark(ctx) if err != nil { - log.Printf("[DashboardAggregation] 读取水位失败: %v", err) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 读取水位失败: %v", err) last = time.Unix(0, 0).UTC() } @@ -216,19 +217,19 @@ func (s *DashboardAggregationService) runScheduledAggregation() { } if err := s.aggregateRange(ctx, start, now); err != nil { - 
log.Printf("[DashboardAggregation] 聚合失败: %v", err) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合失败: %v", err) return } updateErr := s.repo.UpdateAggregationWatermark(ctx, now) if updateErr != nil { - log.Printf("[DashboardAggregation] 更新水位失败: %v", updateErr) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 更新水位失败: %v", updateErr) } - log.Printf("[DashboardAggregation] 聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)", - start.Format(time.RFC3339), - now.Format(time.RFC3339), - time.Since(jobStart).String(), - updateErr == nil, + slog.Debug("[DashboardAggregation] 聚合完成", + "start", start.Format(time.RFC3339), + "end", now.Format(time.RFC3339), + "duration", time.Since(jobStart).String(), + "watermark_updated", updateErr == nil, ) s.maybeCleanupRetention(ctx, now) @@ -261,9 +262,9 @@ func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, updateErr := s.repo.UpdateAggregationWatermark(ctx, endUTC) if updateErr != nil { - log.Printf("[DashboardAggregation] 更新水位失败: %v", updateErr) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 更新水位失败: %v", updateErr) } - log.Printf("[DashboardAggregation] 回填聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)", + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)", startUTC.Format(time.RFC3339), endUTC.Format(time.RFC3339), time.Since(jobStart).String(), @@ -279,7 +280,7 @@ func (s *DashboardAggregationService) aggregateRange(ctx context.Context, start, return nil } if err := s.repo.EnsureUsageLogsPartitions(ctx, end); err != nil { - log.Printf("[DashboardAggregation] 分区检查失败: %v", err) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 分区检查失败: %v", err) } return s.repo.AggregateRange(ctx, start, end) } @@ -298,11 +299,11 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx 
context.Context, aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff) if aggErr != nil { - log.Printf("[DashboardAggregation] 聚合保留清理失败: %v", aggErr) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合保留清理失败: %v", aggErr) } usageErr := s.repo.CleanupUsageLogs(ctx, usageCutoff) if usageErr != nil { - log.Printf("[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr) } if aggErr == nil && usageErr == nil { s.lastRetentionCleanup.Store(now) diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index cd11923e..2af43386 100644 --- a/backend/internal/service/dashboard_service.go +++ b/backend/internal/service/dashboard_service.go @@ -5,11 +5,11 @@ import ( "encoding/json" "errors" "fmt" - "log" "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" ) @@ -113,7 +113,7 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D return cached, nil } if err != nil && !errors.Is(err, ErrDashboardStatsCacheMiss) { - log.Printf("[Dashboard] 仪表盘缓存读取失败: %v", err) + logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存读取失败: %v", err) } } @@ -124,22 +124,30 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D return stats, nil } -func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { - trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType) +func (s *DashboardService) GetUsageTrendWithFilters(ctx 
context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { + trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) if err != nil { return nil, fmt.Errorf("get usage trend with filters: %w", err) } return trend, nil } -func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { - stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType) +func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { + stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) if err != nil { return nil, fmt.Errorf("get model stats with filters: %w", err) } return stats, nil } +func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) { + stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + if err != nil { + return nil, fmt.Errorf("get group stats with filters: %w", err) + } + return stats, nil +} + func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) { data, err := s.cache.GetDashboardStats(ctx) if err != nil 
{ @@ -188,7 +196,7 @@ func (s *DashboardService) refreshDashboardStatsAsync() { stats, err := s.fetchDashboardStats(ctx) if err != nil { - log.Printf("[Dashboard] 仪表盘缓存异步刷新失败: %v", err) + logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存异步刷新失败: %v", err) return } s.applyAggregationStatus(ctx, stats) @@ -220,12 +228,12 @@ func (s *DashboardService) saveDashboardStatsCache(ctx context.Context, stats *u } data, err := json.Marshal(entry) if err != nil { - log.Printf("[Dashboard] 仪表盘缓存序列化失败: %v", err) + logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存序列化失败: %v", err) return } if err := s.cache.SetDashboardStats(ctx, string(data), s.cacheTTL); err != nil { - log.Printf("[Dashboard] 仪表盘缓存写入失败: %v", err) + logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存写入失败: %v", err) } } @@ -237,10 +245,10 @@ func (s *DashboardService) evictDashboardStatsCache(reason error) { defer cancel() if err := s.cache.DeleteDashboardStats(cacheCtx); err != nil { - log.Printf("[Dashboard] 仪表盘缓存清理失败: %v", err) + logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存清理失败: %v", err) } if reason != nil { - log.Printf("[Dashboard] 仪表盘缓存异常,已清理: %v", reason) + logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存异常,已清理: %v", reason) } } @@ -271,7 +279,7 @@ func (s *DashboardService) fetchAggregationUpdatedAt(ctx context.Context) time.T } updatedAt, err := s.aggRepo.GetAggregationWatermark(ctx) if err != nil { - log.Printf("[Dashboard] 读取聚合水位失败: %v", err) + logger.LegacyPrintf("service.dashboard", "[Dashboard] 读取聚合水位失败: %v", err) return time.Unix(0, 0).UTC() } if updatedAt.IsZero() { @@ -319,16 +327,16 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end return trend, nil } -func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) { - stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs) +func (s *DashboardService) GetBatchUserUsageStats(ctx 
context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) { + stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime) if err != nil { return nil, fmt.Errorf("get batch user usage stats: %w", err) } return stats, nil } -func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { - stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs) +func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { + stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime) if err != nil { return nil, fmt.Errorf("get batch api key usage stats: %w", err) } diff --git a/backend/internal/service/data_management_grpc.go b/backend/internal/service/data_management_grpc.go new file mode 100644 index 00000000..aeb3d529 --- /dev/null +++ b/backend/internal/service/data_management_grpc.go @@ -0,0 +1,252 @@ +package service + +import "context" + +type DataManagementPostgresConfig struct { + Host string `json:"host"` + Port int32 `json:"port"` + User string `json:"user"` + Password string `json:"password,omitempty"` + PasswordConfigured bool `json:"password_configured"` + Database string `json:"database"` + SSLMode string `json:"ssl_mode"` + ContainerName string `json:"container_name"` +} + +type DataManagementRedisConfig struct { + Addr string `json:"addr"` + Username string `json:"username"` + Password string `json:"password,omitempty"` + PasswordConfigured bool `json:"password_configured"` + DB int32 `json:"db"` + ContainerName string `json:"container_name"` +} + +type DataManagementS3Config struct { + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` 
+ SecretAccessKey string `json:"secret_access_key,omitempty"` + SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + UseSSL bool `json:"use_ssl"` +} + +type DataManagementConfig struct { + SourceMode string `json:"source_mode"` + BackupRoot string `json:"backup_root"` + SQLitePath string `json:"sqlite_path,omitempty"` + RetentionDays int32 `json:"retention_days"` + KeepLast int32 `json:"keep_last"` + ActivePostgresID string `json:"active_postgres_profile_id"` + ActiveRedisID string `json:"active_redis_profile_id"` + Postgres DataManagementPostgresConfig `json:"postgres"` + Redis DataManagementRedisConfig `json:"redis"` + S3 DataManagementS3Config `json:"s3"` + ActiveS3ProfileID string `json:"active_s3_profile_id"` +} + +type DataManagementTestS3Result struct { + OK bool `json:"ok"` + Message string `json:"message"` +} + +type DataManagementCreateBackupJobInput struct { + BackupType string + UploadToS3 bool + TriggeredBy string + IdempotencyKey string + S3ProfileID string + PostgresID string + RedisID string +} + +type DataManagementListBackupJobsInput struct { + PageSize int32 + PageToken string + Status string + BackupType string +} + +type DataManagementArtifactInfo struct { + LocalPath string `json:"local_path"` + SizeBytes int64 `json:"size_bytes"` + SHA256 string `json:"sha256"` +} + +type DataManagementS3ObjectInfo struct { + Bucket string `json:"bucket"` + Key string `json:"key"` + ETag string `json:"etag"` +} + +type DataManagementBackupJob struct { + JobID string `json:"job_id"` + BackupType string `json:"backup_type"` + Status string `json:"status"` + TriggeredBy string `json:"triggered_by"` + IdempotencyKey string `json:"idempotency_key,omitempty"` + UploadToS3 bool `json:"upload_to_s3"` + S3ProfileID string `json:"s3_profile_id,omitempty"` + PostgresID string `json:"postgres_profile_id,omitempty"` + RedisID string `json:"redis_profile_id,omitempty"` + 
StartedAt string `json:"started_at,omitempty"` + FinishedAt string `json:"finished_at,omitempty"` + ErrorMessage string `json:"error_message,omitempty"` + Artifact DataManagementArtifactInfo `json:"artifact"` + S3Object DataManagementS3ObjectInfo `json:"s3"` +} + +type DataManagementSourceProfile struct { + SourceType string `json:"source_type"` + ProfileID string `json:"profile_id"` + Name string `json:"name"` + IsActive bool `json:"is_active"` + Config DataManagementSourceConfig `json:"config"` + PasswordConfigured bool `json:"password_configured"` + CreatedAt string `json:"created_at,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +type DataManagementSourceConfig struct { + Host string `json:"host"` + Port int32 `json:"port"` + User string `json:"user"` + Password string `json:"password,omitempty"` + Database string `json:"database"` + SSLMode string `json:"ssl_mode"` + Addr string `json:"addr"` + Username string `json:"username"` + DB int32 `json:"db"` + ContainerName string `json:"container_name"` +} + +type DataManagementCreateSourceProfileInput struct { + SourceType string + ProfileID string + Name string + Config DataManagementSourceConfig + SetActive bool +} + +type DataManagementUpdateSourceProfileInput struct { + SourceType string + ProfileID string + Name string + Config DataManagementSourceConfig +} + +type DataManagementS3Profile struct { + ProfileID string `json:"profile_id"` + Name string `json:"name"` + IsActive bool `json:"is_active"` + S3 DataManagementS3Config `json:"s3"` + SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` + CreatedAt string `json:"created_at,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +type DataManagementCreateS3ProfileInput struct { + ProfileID string + Name string + S3 DataManagementS3Config + SetActive bool +} + +type DataManagementUpdateS3ProfileInput struct { + ProfileID string + Name string + S3 DataManagementS3Config +} + +type DataManagementListBackupJobsResult 
struct { + Items []DataManagementBackupJob `json:"items"` + NextPageToken string `json:"next_page_token,omitempty"` +} + +func (s *DataManagementService) GetConfig(ctx context.Context) (DataManagementConfig, error) { + _ = ctx + return DataManagementConfig{}, s.deprecatedError() +} + +func (s *DataManagementService) UpdateConfig(ctx context.Context, cfg DataManagementConfig) (DataManagementConfig, error) { + _, _ = ctx, cfg + return DataManagementConfig{}, s.deprecatedError() +} + +func (s *DataManagementService) ListSourceProfiles(ctx context.Context, sourceType string) ([]DataManagementSourceProfile, error) { + _, _ = ctx, sourceType + return nil, s.deprecatedError() +} + +func (s *DataManagementService) CreateSourceProfile(ctx context.Context, input DataManagementCreateSourceProfileInput) (DataManagementSourceProfile, error) { + _, _ = ctx, input + return DataManagementSourceProfile{}, s.deprecatedError() +} + +func (s *DataManagementService) UpdateSourceProfile(ctx context.Context, input DataManagementUpdateSourceProfileInput) (DataManagementSourceProfile, error) { + _, _ = ctx, input + return DataManagementSourceProfile{}, s.deprecatedError() +} + +func (s *DataManagementService) DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error { + _, _, _ = ctx, sourceType, profileID + return s.deprecatedError() +} + +func (s *DataManagementService) SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (DataManagementSourceProfile, error) { + _, _, _ = ctx, sourceType, profileID + return DataManagementSourceProfile{}, s.deprecatedError() +} + +func (s *DataManagementService) ValidateS3(ctx context.Context, cfg DataManagementS3Config) (DataManagementTestS3Result, error) { + _, _ = ctx, cfg + return DataManagementTestS3Result{}, s.deprecatedError() +} + +func (s *DataManagementService) ListS3Profiles(ctx context.Context) ([]DataManagementS3Profile, error) { + _ = ctx + return nil, s.deprecatedError() +} + +func (s 
*DataManagementService) CreateS3Profile(ctx context.Context, input DataManagementCreateS3ProfileInput) (DataManagementS3Profile, error) { + _, _ = ctx, input + return DataManagementS3Profile{}, s.deprecatedError() +} + +func (s *DataManagementService) UpdateS3Profile(ctx context.Context, input DataManagementUpdateS3ProfileInput) (DataManagementS3Profile, error) { + _, _ = ctx, input + return DataManagementS3Profile{}, s.deprecatedError() +} + +func (s *DataManagementService) DeleteS3Profile(ctx context.Context, profileID string) error { + _, _ = ctx, profileID + return s.deprecatedError() +} + +func (s *DataManagementService) SetActiveS3Profile(ctx context.Context, profileID string) (DataManagementS3Profile, error) { + _, _ = ctx, profileID + return DataManagementS3Profile{}, s.deprecatedError() +} + +func (s *DataManagementService) CreateBackupJob(ctx context.Context, input DataManagementCreateBackupJobInput) (DataManagementBackupJob, error) { + _, _ = ctx, input + return DataManagementBackupJob{}, s.deprecatedError() +} + +func (s *DataManagementService) ListBackupJobs(ctx context.Context, input DataManagementListBackupJobsInput) (DataManagementListBackupJobsResult, error) { + _, _ = ctx, input + return DataManagementListBackupJobsResult{}, s.deprecatedError() +} + +func (s *DataManagementService) GetBackupJob(ctx context.Context, jobID string) (DataManagementBackupJob, error) { + _, _ = ctx, jobID + return DataManagementBackupJob{}, s.deprecatedError() +} + +func (s *DataManagementService) deprecatedError() error { + return ErrDataManagementDeprecated.WithMetadata(map[string]string{"socket_path": s.SocketPath()}) +} diff --git a/backend/internal/service/data_management_grpc_test.go b/backend/internal/service/data_management_grpc_test.go new file mode 100644 index 00000000..286eb58d --- /dev/null +++ b/backend/internal/service/data_management_grpc_test.go @@ -0,0 +1,36 @@ +package service + +import ( + "context" + "path/filepath" + "testing" + + infraerrors 
"github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestDataManagementService_DeprecatedRPCMethods(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(t.TempDir(), "datamanagement.sock") + svc := NewDataManagementServiceWithOptions(socketPath, 0) + + _, err := svc.GetConfig(context.Background()) + assertDeprecatedDataManagementError(t, err, socketPath) + + _, err = svc.CreateBackupJob(context.Background(), DataManagementCreateBackupJobInput{BackupType: "full"}) + assertDeprecatedDataManagementError(t, err, socketPath) + + err = svc.DeleteS3Profile(context.Background(), "s3-default") + assertDeprecatedDataManagementError(t, err, socketPath) +} + +func assertDeprecatedDataManagementError(t *testing.T, err error, socketPath string) { + t.Helper() + + require.Error(t, err) + statusCode, status := infraerrors.ToHTTP(err) + require.Equal(t, 503, statusCode) + require.Equal(t, DataManagementDeprecatedReason, status.Reason) + require.Equal(t, socketPath, status.Metadata["socket_path"]) +} diff --git a/backend/internal/service/data_management_service.go b/backend/internal/service/data_management_service.go new file mode 100644 index 00000000..b525c0fa --- /dev/null +++ b/backend/internal/service/data_management_service.go @@ -0,0 +1,95 @@ +package service + +import ( + "context" + "strings" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +const ( + DefaultDataManagementAgentSocketPath = "/tmp/sub2api-datamanagement.sock" + LegacyBackupAgentSocketPath = "/tmp/sub2api-backup.sock" + + DataManagementDeprecatedReason = "DATA_MANAGEMENT_DEPRECATED" + DataManagementAgentSocketMissingReason = "DATA_MANAGEMENT_AGENT_SOCKET_MISSING" + DataManagementAgentUnavailableReason = "DATA_MANAGEMENT_AGENT_UNAVAILABLE" + + // Deprecated: keep old names for compatibility. 
+ DefaultBackupAgentSocketPath = DefaultDataManagementAgentSocketPath + BackupAgentSocketMissingReason = DataManagementAgentSocketMissingReason + BackupAgentUnavailableReason = DataManagementAgentUnavailableReason +) + +var ( + ErrDataManagementDeprecated = infraerrors.ServiceUnavailable( + DataManagementDeprecatedReason, + "data management feature is deprecated", + ) + ErrDataManagementAgentSocketMissing = infraerrors.ServiceUnavailable( + DataManagementAgentSocketMissingReason, + "data management agent socket is missing", + ) + ErrDataManagementAgentUnavailable = infraerrors.ServiceUnavailable( + DataManagementAgentUnavailableReason, + "data management agent is unavailable", + ) + + // Deprecated: keep old names for compatibility. + ErrBackupAgentSocketMissing = ErrDataManagementAgentSocketMissing + ErrBackupAgentUnavailable = ErrDataManagementAgentUnavailable +) + +type DataManagementAgentHealth struct { + Enabled bool + Reason string + SocketPath string + Agent *DataManagementAgentInfo +} + +type DataManagementAgentInfo struct { + Status string + Version string + UptimeSeconds int64 +} + +type DataManagementService struct { + socketPath string +} + +func NewDataManagementService() *DataManagementService { + return NewDataManagementServiceWithOptions(DefaultDataManagementAgentSocketPath, 500*time.Millisecond) +} + +func NewDataManagementServiceWithOptions(socketPath string, dialTimeout time.Duration) *DataManagementService { + _ = dialTimeout + path := strings.TrimSpace(socketPath) + if path == "" { + path = DefaultDataManagementAgentSocketPath + } + return &DataManagementService{ + socketPath: path, + } +} + +func (s *DataManagementService) SocketPath() string { + if s == nil || strings.TrimSpace(s.socketPath) == "" { + return DefaultDataManagementAgentSocketPath + } + return s.socketPath +} + +func (s *DataManagementService) GetAgentHealth(ctx context.Context) DataManagementAgentHealth { + _ = ctx + return DataManagementAgentHealth{ + Enabled: false, + Reason: 
DataManagementDeprecatedReason, + SocketPath: s.SocketPath(), + } +} + +func (s *DataManagementService) EnsureAgentEnabled(ctx context.Context) error { + _ = ctx + return ErrDataManagementDeprecated.WithMetadata(map[string]string{"socket_path": s.SocketPath()}) +} diff --git a/backend/internal/service/data_management_service_test.go b/backend/internal/service/data_management_service_test.go new file mode 100644 index 00000000..65489d2e --- /dev/null +++ b/backend/internal/service/data_management_service_test.go @@ -0,0 +1,37 @@ +package service + +import ( + "context" + "path/filepath" + "testing" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestDataManagementService_GetAgentHealth_Deprecated(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(t.TempDir(), "unused.sock") + svc := NewDataManagementServiceWithOptions(socketPath, 0) + health := svc.GetAgentHealth(context.Background()) + + require.False(t, health.Enabled) + require.Equal(t, DataManagementDeprecatedReason, health.Reason) + require.Equal(t, socketPath, health.SocketPath) + require.Nil(t, health.Agent) +} + +func TestDataManagementService_EnsureAgentEnabled_Deprecated(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(t.TempDir(), "unused.sock") + svc := NewDataManagementServiceWithOptions(socketPath, 100) + err := svc.EnsureAgentEnabled(context.Background()) + require.Error(t, err) + + statusCode, status := infraerrors.ToHTTP(err) + require.Equal(t, 503, statusCode) + require.Equal(t, DataManagementDeprecatedReason, status.Reason) + require.Equal(t, socketPath, status.Metadata["socket_path"]) +} diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 0295c23b..46282321 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -24,6 +24,7 @@ const ( PlatformOpenAI = domain.PlatformOpenAI PlatformGemini = 
domain.PlatformGemini PlatformAntigravity = domain.PlatformAntigravity + PlatformSora = domain.PlatformSora ) // Account type constants @@ -73,11 +74,12 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid" // Setting keys const ( // 注册设置 - SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册 - SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证 - SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能 - SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证) - SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册 + SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册 + SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证 + SettingKeyRegistrationEmailSuffixWhitelist = "registration_email_suffix_whitelist" // 注册邮箱后缀白名单(JSON 数组) + SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能 + SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证) + SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册 // 邮件服务设置 SettingKeySMTPHost = "smtp_host" // SMTP服务器地址 @@ -103,6 +105,7 @@ const ( SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url" // OEM设置 + SettingKeySoraClientEnabled = "sora_client_enabled" // 是否启用 Sora 客户端(管理员手动控制) SettingKeySiteName = "site_name" // 网站名称 SettingKeySiteLogo = "site_logo" // 网站Logo (base64) SettingKeySiteSubtitle = "site_subtitle" // 网站副标题 @@ -111,12 +114,14 @@ const ( SettingKeyDocURL = "doc_url" // 文档链接 SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src) SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮 - SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示“购买订阅”页面入口 - SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // “购买订阅”页面 URL(作为 iframe src) + SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示"购买订阅"页面入口 + 
SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // "购买订阅"页面 URL(作为 iframe src) + SettingKeyCustomMenuItems = "custom_menu_items" // 自定义菜单项(JSON 数组) // 默认配置 - SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 - SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 + SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 + SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 + SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON) // 管理员 API Key SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成) @@ -160,12 +165,46 @@ const ( // SettingKeyOpsAdvancedSettings stores JSON config for ops advanced settings (data retention, aggregation). SettingKeyOpsAdvancedSettings = "ops_advanced_settings" + // SettingKeyOpsRuntimeLogConfig stores JSON config for runtime log settings. + SettingKeyOpsRuntimeLogConfig = "ops_runtime_log_config" + // ========================= // Stream Timeout Handling // ========================= // SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling. 
SettingKeyStreamTimeoutSettings = "stream_timeout_settings" + + // ========================= + // Sora S3 存储配置 + // ========================= + + SettingKeySoraS3Enabled = "sora_s3_enabled" // 是否启用 Sora S3 存储 + SettingKeySoraS3Endpoint = "sora_s3_endpoint" // S3 端点地址 + SettingKeySoraS3Region = "sora_s3_region" // S3 区域 + SettingKeySoraS3Bucket = "sora_s3_bucket" // S3 存储桶名称 + SettingKeySoraS3AccessKeyID = "sora_s3_access_key_id" // S3 Access Key ID + SettingKeySoraS3SecretAccessKey = "sora_s3_secret_access_key" // S3 Secret Access Key(加密存储) + SettingKeySoraS3Prefix = "sora_s3_prefix" // S3 对象键前缀 + SettingKeySoraS3ForcePathStyle = "sora_s3_force_path_style" // 是否强制 Path Style(兼容 MinIO 等) + SettingKeySoraS3CDNURL = "sora_s3_cdn_url" // CDN 加速 URL(可选) + SettingKeySoraS3Profiles = "sora_s3_profiles" // Sora S3 多配置(JSON) + + // ========================= + // Sora 用户存储配额 + // ========================= + + SettingKeySoraDefaultStorageQuotaBytes = "sora_default_storage_quota_bytes" // 新用户默认 Sora 存储配额(字节) + + // ========================= + // Claude Code Version Check + // ========================= + + // SettingKeyMinClaudeCodeVersion 最低 Claude Code 版本号要求 (semver, 如 "2.1.0",空值=不检查) + SettingKeyMinClaudeCodeVersion = "min_claude_code_version" + + // SettingKeyAllowUngroupedKeyScheduling 允许未分组 API Key 调度(默认 false:未分组 Key 返回 403) + SettingKeyAllowUngroupedKeyScheduling = "allow_ungrouped_key_scheduling" ) // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). 
diff --git a/backend/internal/service/email_queue_service.go b/backend/internal/service/email_queue_service.go index 6c975c69..d8f0a518 100644 --- a/backend/internal/service/email_queue_service.go +++ b/backend/internal/service/email_queue_service.go @@ -3,9 +3,10 @@ package service import ( "context" "fmt" - "log" "sync" "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) // Task type constants @@ -56,7 +57,7 @@ func (s *EmailQueueService) start() { s.wg.Add(1) go s.worker(i) } - log.Printf("[EmailQueue] Started %d workers", s.workers) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Started %d workers", s.workers) } // worker 工作协程 @@ -68,7 +69,7 @@ func (s *EmailQueueService) worker(id int) { case task := <-s.taskChan: s.processTask(id, task) case <-s.stopChan: - log.Printf("[EmailQueue] Worker %d stopping", id) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d stopping", id) return } } @@ -82,18 +83,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) { switch task.TaskType { case TaskTypeVerifyCode: if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil { - log.Printf("[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err) } else { - log.Printf("[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email) } case TaskTypePasswordReset: if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL); err != nil { - log.Printf("[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, 
task.Email, err) } else { - log.Printf("[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email) } default: - log.Printf("[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType) } } @@ -107,7 +108,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error { select { case s.taskChan <- task: - log.Printf("[EmailQueue] Enqueued verify code task for %s", email) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Enqueued verify code task for %s", email) return nil default: return fmt.Errorf("email queue is full") @@ -125,7 +126,7 @@ func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL strin select { case s.taskChan <- task: - log.Printf("[EmailQueue] Enqueued password reset task for %s", email) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Enqueued password reset task for %s", email) return nil default: return fmt.Errorf("email queue is full") @@ -136,5 +137,5 @@ func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL strin func (s *EmailQueueService) Stop() { close(s.stopChan) s.wg.Wait() - log.Println("[EmailQueue] All workers stopped") + logger.LegacyPrintf("service.email_queue", "%s", "[EmailQueue] All workers stopped") } diff --git a/backend/internal/service/error_passthrough_runtime_test.go b/backend/internal/service/error_passthrough_runtime_test.go index 0a45e57a..7032d15b 100644 --- a/backend/internal/service/error_passthrough_runtime_test.go +++ b/backend/internal/service/error_passthrough_runtime_test.go @@ -76,7 +76,7 @@ func TestOpenAIHandleErrorResponse_NoRuleKeepsDefault(t *testing.T) { } account := &Account{ID: 12, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} - _, err := 
svc.handleErrorResponse(context.Background(), resp, c, account) + _, err := svc.handleErrorResponse(context.Background(), resp, c, account, nil) require.Error(t, err) assert.Equal(t, http.StatusBadGateway, rec.Code) @@ -157,7 +157,7 @@ func TestOpenAIHandleErrorResponse_AppliesRuleFor422(t *testing.T) { } account := &Account{ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} - _, err := svc.handleErrorResponse(context.Background(), resp, c, account) + _, err := svc.handleErrorResponse(context.Background(), resp, c, account, nil) require.Error(t, err) assert.Equal(t, http.StatusTeapot, rec.Code) diff --git a/backend/internal/service/error_passthrough_service.go b/backend/internal/service/error_passthrough_service.go index da8c9ccf..26fdf9a7 100644 --- a/backend/internal/service/error_passthrough_service.go +++ b/backend/internal/service/error_passthrough_service.go @@ -2,13 +2,13 @@ package service import ( "context" - "log" "sort" "strings" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) // ErrorPassthroughRepository 定义错误透传规则的数据访问接口 @@ -72,9 +72,9 @@ func NewErrorPassthroughService( // 启动时加载规则到本地缓存 ctx := context.Background() if err := svc.reloadRulesFromDB(ctx); err != nil { - log.Printf("[ErrorPassthroughService] Failed to load rules from DB on startup: %v", err) + logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to load rules from DB on startup: %v", err) if fallbackErr := svc.refreshLocalCache(ctx); fallbackErr != nil { - log.Printf("[ErrorPassthroughService] Failed to load rules from cache fallback on startup: %v", fallbackErr) + logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to load rules from cache fallback on startup: %v", fallbackErr) } } @@ -82,7 +82,7 @@ func NewErrorPassthroughService( if cache != nil { cache.SubscribeUpdates(ctx, func() { if err := svc.refreshLocalCache(context.Background()); err != nil { - 
log.Printf("[ErrorPassthroughService] Failed to refresh cache on notification: %v", err) + logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to refresh cache on notification: %v", err) } }) } @@ -192,7 +192,7 @@ func (s *ErrorPassthroughService) getCachedRules() []*cachedPassthroughRule { // 如果本地缓存为空,尝试刷新 ctx := context.Background() if err := s.refreshLocalCache(ctx); err != nil { - log.Printf("[ErrorPassthroughService] Failed to refresh cache: %v", err) + logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to refresh cache: %v", err) return nil } @@ -225,7 +225,7 @@ func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error { // 更新 Redis 缓存 if s.cache != nil { if err := s.cache.Set(ctx, rules); err != nil { - log.Printf("[ErrorPassthroughService] Failed to set cache: %v", err) + logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to set cache: %v", err) } } @@ -288,13 +288,13 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) { // 先失效缓存,避免后续刷新读到陈旧规则。 if s.cache != nil { if err := s.cache.Invalidate(ctx); err != nil { - log.Printf("[ErrorPassthroughService] Failed to invalidate cache: %v", err) + logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to invalidate cache: %v", err) } } // 刷新本地缓存 if err := s.reloadRulesFromDB(ctx); err != nil { - log.Printf("[ErrorPassthroughService] Failed to refresh local cache: %v", err) + logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to refresh local cache: %v", err) // 刷新失败时清空本地缓存,避免继续使用陈旧规则。 s.clearLocalCache() } @@ -302,7 +302,7 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) { // 通知其他实例 if s.cache != nil { if err := s.cache.NotifyUpdate(ctx); err != nil { - log.Printf("[ErrorPassthroughService] Failed to notify cache update: %v", err) + logger.LegacyPrintf("service.error_passthrough", 
"[ErrorPassthroughService] Failed to notify cache update: %v", err) } } } diff --git a/backend/internal/service/gateway_account_selection_test.go b/backend/internal/service/gateway_account_selection_test.go new file mode 100644 index 00000000..0a82fade --- /dev/null +++ b/backend/internal/service/gateway_account_selection_test.go @@ -0,0 +1,206 @@ +//go:build unit + +package service + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// --- helpers --- + +func testTimePtr(t time.Time) *time.Time { return &t } + +func makeAccWithLoad(id int64, priority int, loadRate int, lastUsed *time.Time, accType string) accountWithLoad { + return accountWithLoad{ + account: &Account{ + ID: id, + Priority: priority, + LastUsedAt: lastUsed, + Type: accType, + Schedulable: true, + Status: StatusActive, + }, + loadInfo: &AccountLoadInfo{ + AccountID: id, + CurrentConcurrency: 0, + LoadRate: loadRate, + }, + } +} + +// --- sortAccountsByPriorityAndLastUsed --- + +func TestSortAccountsByPriorityAndLastUsed_ByPriority(t *testing.T) { + now := time.Now() + accounts := []*Account{ + {ID: 1, Priority: 5, LastUsedAt: testTimePtr(now)}, + {ID: 2, Priority: 1, LastUsedAt: testTimePtr(now)}, + {ID: 3, Priority: 3, LastUsedAt: testTimePtr(now)}, + } + sortAccountsByPriorityAndLastUsed(accounts, false) + require.Equal(t, int64(2), accounts[0].ID, "优先级最低的排第一") + require.Equal(t, int64(3), accounts[1].ID) + require.Equal(t, int64(1), accounts[2].ID) +} + +func TestSortAccountsByPriorityAndLastUsed_SamePriorityByLastUsed(t *testing.T) { + now := time.Now() + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: testTimePtr(now)}, + {ID: 2, Priority: 1, LastUsedAt: testTimePtr(now.Add(-1 * time.Hour))}, + {ID: 3, Priority: 1, LastUsedAt: nil}, + } + sortAccountsByPriorityAndLastUsed(accounts, false) + require.Equal(t, int64(3), accounts[0].ID, "nil LastUsedAt 排最前") + require.Equal(t, int64(2), accounts[1].ID, "更早使用的排前面") + require.Equal(t, int64(1), accounts[2].ID) 
+} + +func TestSortAccountsByPriorityAndLastUsed_PreferOAuth(t *testing.T) { + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey}, + {ID: 2, Priority: 1, LastUsedAt: nil, Type: AccountTypeOAuth}, + } + sortAccountsByPriorityAndLastUsed(accounts, true) + require.Equal(t, int64(2), accounts[0].ID, "preferOAuth 时 OAuth 账号排前面") +} + +func TestSortAccountsByPriorityAndLastUsed_StableSort(t *testing.T) { + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey}, + {ID: 2, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey}, + {ID: 3, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey}, + } + + // sortAccountsByPriorityAndLastUsed 内部会在同组(Priority+LastUsedAt)内做随机打散, + // 因此这里不再断言“稳定排序”。我们只验证: + // 1) 元素集合不变;2) 多次运行能产生不同的顺序。 + seenFirst := map[int64]bool{} + for i := 0; i < 100; i++ { + cpy := make([]*Account, len(accounts)) + copy(cpy, accounts) + sortAccountsByPriorityAndLastUsed(cpy, false) + seenFirst[cpy[0].ID] = true + + ids := map[int64]bool{} + for _, a := range cpy { + ids[a.ID] = true + } + require.True(t, ids[1] && ids[2] && ids[3]) + } + require.GreaterOrEqual(t, len(seenFirst), 2, "同组账号应能被随机打散") +} + +func TestSortAccountsByPriorityAndLastUsed_MixedPriorityAndTime(t *testing.T) { + now := time.Now() + accounts := []*Account{ + {ID: 1, Priority: 2, LastUsedAt: nil}, + {ID: 2, Priority: 1, LastUsedAt: testTimePtr(now)}, + {ID: 3, Priority: 1, LastUsedAt: testTimePtr(now.Add(-1 * time.Hour))}, + {ID: 4, Priority: 2, LastUsedAt: testTimePtr(now.Add(-2 * time.Hour))}, + } + sortAccountsByPriorityAndLastUsed(accounts, false) + // 优先级1排前:nil < earlier + require.Equal(t, int64(3), accounts[0].ID, "优先级1 + 更早") + require.Equal(t, int64(2), accounts[1].ID, "优先级1 + 现在") + // 优先级2排后:nil < time + require.Equal(t, int64(1), accounts[2].ID, "优先级2 + nil") + require.Equal(t, int64(4), accounts[3].ID, "优先级2 + 有时间") +} + +// --- filterByMinPriority --- + +func TestFilterByMinPriority_Empty(t *testing.T) { 
+ result := filterByMinPriority(nil) + require.Nil(t, result) +} + +func TestFilterByMinPriority_SelectsMinPriority(t *testing.T) { + accounts := []accountWithLoad{ + makeAccWithLoad(1, 5, 10, nil, AccountTypeAPIKey), + makeAccWithLoad(2, 1, 10, nil, AccountTypeAPIKey), + makeAccWithLoad(3, 1, 20, nil, AccountTypeAPIKey), + makeAccWithLoad(4, 2, 10, nil, AccountTypeAPIKey), + } + result := filterByMinPriority(accounts) + require.Len(t, result, 2) + require.Equal(t, int64(2), result[0].account.ID) + require.Equal(t, int64(3), result[1].account.ID) +} + +// --- filterByMinLoadRate --- + +func TestFilterByMinLoadRate_Empty(t *testing.T) { + result := filterByMinLoadRate(nil) + require.Nil(t, result) +} + +func TestFilterByMinLoadRate_SelectsMinLoadRate(t *testing.T) { + accounts := []accountWithLoad{ + makeAccWithLoad(1, 1, 30, nil, AccountTypeAPIKey), + makeAccWithLoad(2, 1, 10, nil, AccountTypeAPIKey), + makeAccWithLoad(3, 1, 10, nil, AccountTypeAPIKey), + makeAccWithLoad(4, 1, 20, nil, AccountTypeAPIKey), + } + result := filterByMinLoadRate(accounts) + require.Len(t, result, 2) + require.Equal(t, int64(2), result[0].account.ID) + require.Equal(t, int64(3), result[1].account.ID) +} + +// --- selectByLRU --- + +func TestSelectByLRU_Empty(t *testing.T) { + result := selectByLRU(nil, false) + require.Nil(t, result) +} + +func TestSelectByLRU_Single(t *testing.T) { + accounts := []accountWithLoad{makeAccWithLoad(1, 1, 10, nil, AccountTypeAPIKey)} + result := selectByLRU(accounts, false) + require.NotNil(t, result) + require.Equal(t, int64(1), result.account.ID) +} + +func TestSelectByLRU_NilLastUsedAtWins(t *testing.T) { + now := time.Now() + accounts := []accountWithLoad{ + makeAccWithLoad(1, 1, 10, testTimePtr(now), AccountTypeAPIKey), + makeAccWithLoad(2, 1, 10, nil, AccountTypeAPIKey), + makeAccWithLoad(3, 1, 10, testTimePtr(now.Add(-1*time.Hour)), AccountTypeAPIKey), + } + result := selectByLRU(accounts, false) + require.NotNil(t, result) + require.Equal(t, 
int64(2), result.account.ID) +} + +func TestSelectByLRU_EarliestTimeWins(t *testing.T) { + now := time.Now() + accounts := []accountWithLoad{ + makeAccWithLoad(1, 1, 10, testTimePtr(now), AccountTypeAPIKey), + makeAccWithLoad(2, 1, 10, testTimePtr(now.Add(-1*time.Hour)), AccountTypeAPIKey), + makeAccWithLoad(3, 1, 10, testTimePtr(now.Add(-2*time.Hour)), AccountTypeAPIKey), + } + result := selectByLRU(accounts, false) + require.NotNil(t, result) + require.Equal(t, int64(3), result.account.ID) +} + +func TestSelectByLRU_TiePreferOAuth(t *testing.T) { + now := time.Now() + // 账号 1/2 LastUsedAt 相同,且同为最小值。 + accounts := []accountWithLoad{ + makeAccWithLoad(1, 1, 10, testTimePtr(now), AccountTypeAPIKey), + makeAccWithLoad(2, 1, 10, testTimePtr(now), AccountTypeOAuth), + makeAccWithLoad(3, 1, 10, testTimePtr(now.Add(1*time.Hour)), AccountTypeAPIKey), + } + for i := 0; i < 50; i++ { + result := selectByLRU(accounts, true) + require.NotNil(t, result) + require.Equal(t, AccountTypeOAuth, result.account.Type) + require.Equal(t, int64(2), result.account.ID) + } +} diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_benchmark_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_benchmark_test.go new file mode 100644 index 00000000..37fd709f --- /dev/null +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_benchmark_test.go @@ -0,0 +1,56 @@ +package service + +import "testing" + +func BenchmarkGatewayService_ParseSSEUsage_MessageStart(b *testing.B) { + svc := &GatewayService{} + data := `{"type":"message_start","message":{"usage":{"input_tokens":123,"cache_creation_input_tokens":45,"cache_read_input_tokens":6,"cached_tokens":6,"cache_creation":{"ephemeral_5m_input_tokens":20,"ephemeral_1h_input_tokens":25}}}}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + usage := &ClaudeUsage{} + svc.parseSSEUsage(data, usage) + } +} + +func BenchmarkGatewayService_ParseSSEUsagePassthrough_MessageStart(b *testing.B) 
{ + svc := &GatewayService{} + data := `{"type":"message_start","message":{"usage":{"input_tokens":123,"cache_creation_input_tokens":45,"cache_read_input_tokens":6,"cached_tokens":6,"cache_creation":{"ephemeral_5m_input_tokens":20,"ephemeral_1h_input_tokens":25}}}}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + usage := &ClaudeUsage{} + svc.parseSSEUsagePassthrough(data, usage) + } +} + +func BenchmarkGatewayService_ParseSSEUsage_MessageDelta(b *testing.B) { + svc := &GatewayService{} + data := `{"type":"message_delta","usage":{"output_tokens":456,"cache_creation_input_tokens":30,"cache_read_input_tokens":7,"cached_tokens":7,"cache_creation":{"ephemeral_5m_input_tokens":10,"ephemeral_1h_input_tokens":20}}}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + usage := &ClaudeUsage{} + svc.parseSSEUsage(data, usage) + } +} + +func BenchmarkGatewayService_ParseSSEUsagePassthrough_MessageDelta(b *testing.B) { + svc := &GatewayService{} + data := `{"type":"message_delta","usage":{"output_tokens":456,"cache_creation_input_tokens":30,"cache_read_input_tokens":7,"cached_tokens":7,"cache_creation":{"ephemeral_5m_input_tokens":10,"ephemeral_1h_input_tokens":20}}}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + usage := &ClaudeUsage{} + svc.parseSSEUsagePassthrough(data, usage) + } +} + +func BenchmarkParseClaudeUsageFromResponseBody(b *testing.B) { + body := []byte(`{"id":"msg_123","type":"message","usage":{"input_tokens":123,"output_tokens":456,"cache_creation_input_tokens":45,"cache_read_input_tokens":6,"cached_tokens":6,"cache_creation":{"ephemeral_5m_input_tokens":20,"ephemeral_1h_input_tokens":25}}}`) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = parseClaudeUsageFromResponseBody(body) + } +} diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go new file mode 100644 index 00000000..f8c0ecda 
--- /dev/null +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -0,0 +1,875 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +type anthropicHTTPUpstreamRecorder struct { + lastReq *http.Request + lastBody []byte + resp *http.Response + err error +} + +func newAnthropicAPIKeyAccountForTest() *Account { + return &Account{ + ID: 201, + Name: "anthropic-apikey-pass-test", + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "upstream-anthropic-key", + "base_url": "https://api.anthropic.com", + }, + Extra: map[string]any{ + "anthropic_passthrough": true, + }, + Status: StatusActive, + Schedulable: true, + } +} + +func (u *anthropicHTTPUpstreamRecorder) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + u.lastReq = req + if req != nil && req.Body != nil { + b, _ := io.ReadAll(req.Body) + u.lastBody = b + _ = req.Body.Close() + req.Body = io.NopCloser(bytes.NewReader(b)) + } + if u.err != nil { + return nil, u.err + } + return u.resp, nil +} + +func (u *anthropicHTTPUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { + return u.Do(req, proxyURL, accountID, accountConcurrency) +} + +type streamReadCloser struct { + payload []byte + sent bool + err error +} + +func (r *streamReadCloser) Read(p []byte) (int, error) { + if !r.sent { + r.sent = true + n := copy(p, r.payload) + return n, nil + } + if r.err != nil { + return 0, r.err + } + return 0, io.EOF +} + +func (r *streamReadCloser) Close() error { 
return nil } + +type failWriteResponseWriter struct { + gin.ResponseWriter +} + +func (w *failWriteResponseWriter) Write(data []byte) (int, error) { + return 0, errors.New("client disconnected") +} + +func (w *failWriteResponseWriter) WriteString(_ string) (int, error) { + return 0, errors.New("client disconnected") +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAndAuthReplacement(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + c.Request.Header.Set("User-Agent", "claude-cli/1.0.0") + c.Request.Header.Set("Authorization", "Bearer inbound-token") + c.Request.Header.Set("X-Api-Key", "inbound-api-key") + c.Request.Header.Set("X-Goog-Api-Key", "inbound-goog-key") + c.Request.Header.Set("Cookie", "secret=1") + c.Request.Header.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14") + + body := []byte(`{"model":"claude-3-7-sonnet-20250219","stream":true,"system":[{"type":"text","text":"x-anthropic-billing-header keep"}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`) + parsed := &ParsedRequest{ + Body: body, + Model: "claude-3-7-sonnet-20250219", + Stream: true, + } + + upstreamSSE := strings.Join([]string{ + `data: {"type":"message_start","message":{"usage":{"input_tokens":9,"cached_tokens":7}}}`, + "", + `data: {"type":"message_delta","usage":{"output_tokens":3}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "x-request-id": []string{"rid-anthropic-pass"}, + "Set-Cookie": []string{"secret=upstream"}, + }, + Body: io.NopCloser(strings.NewReader(upstreamSSE)), + }, + } + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + }, + 
httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + deferredService: &DeferredService{}, + billingCacheService: nil, + } + + account := &Account{ + ID: 101, + Name: "anthropic-apikey-pass", + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "upstream-anthropic-key", + "base_url": "https://api.anthropic.com", + "model_mapping": map[string]any{"claude-3-7-sonnet-20250219": "claude-3-haiku-20240307"}, + }, + Extra: map[string]any{ + "anthropic_passthrough": true, + }, + Status: StatusActive, + Schedulable: true, + } + + result, err := svc.Forward(context.Background(), c, account, parsed) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + + require.Equal(t, body, upstream.lastBody, "透传模式不应改写上游请求体") + require.Equal(t, "claude-3-7-sonnet-20250219", gjson.GetBytes(upstream.lastBody, "model").String()) + + require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key")) + require.Empty(t, upstream.lastReq.Header.Get("authorization")) + require.Empty(t, upstream.lastReq.Header.Get("x-goog-api-key")) + require.Empty(t, upstream.lastReq.Header.Get("cookie")) + require.Equal(t, "2023-06-01", upstream.lastReq.Header.Get("anthropic-version")) + require.Equal(t, "interleaved-thinking-2025-05-14", upstream.lastReq.Header.Get("anthropic-beta")) + require.Empty(t, upstream.lastReq.Header.Get("x-stainless-lang"), "API Key 透传不应注入 OAuth 指纹头") + + require.Contains(t, rec.Body.String(), `"cached_tokens":7`) + require.NotContains(t, rec.Body.String(), `"cache_read_input_tokens":7`, "透传输出不应被网关改写") + require.Equal(t, 7, result.Usage.CacheReadInputTokens, "计费 usage 解析应保留 cached_tokens 兼容") + require.Empty(t, rec.Header().Get("Set-Cookie"), "响应头应经过安全过滤") + rawBody, ok := c.Get(OpsUpstreamRequestBodyKey) + require.True(t, ok) + bodyBytes, ok := rawBody.([]byte) + require.True(t, ok, "应以 []byte 形式缓存上游请求体,避免重复 string 拷贝") + require.Equal(t, body, 
bodyBytes) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBody(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil) + c.Request.Header.Set("Authorization", "Bearer inbound-token") + c.Request.Header.Set("X-Api-Key", "inbound-api-key") + c.Request.Header.Set("Cookie", "secret=1") + + body := []byte(`{"model":"claude-3-5-sonnet-latest","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}],"thinking":{"type":"enabled"}}`) + parsed := &ParsedRequest{ + Body: body, + Model: "claude-3-5-sonnet-latest", + } + + upstreamRespBody := `{"input_tokens":42}` + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "x-request-id": []string{"rid-count"}, + "Set-Cookie": []string{"secret=upstream"}, + }, + Body: io.NopCloser(strings.NewReader(upstreamRespBody)), + }, + } + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + }, + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + } + + account := &Account{ + ID: 102, + Name: "anthropic-apikey-pass-count", + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "upstream-anthropic-key", + "base_url": "https://api.anthropic.com", + "model_mapping": map[string]any{"claude-3-5-sonnet-latest": "claude-3-opus-20240229"}, + }, + Extra: map[string]any{ + "anthropic_passthrough": true, + }, + Status: StatusActive, + Schedulable: true, + } + + err := svc.ForwardCountTokens(context.Background(), c, account, parsed) + require.NoError(t, err) + + require.Equal(t, body, upstream.lastBody, "count_tokens 透传模式不应改写请求体") + require.Equal(t, "claude-3-5-sonnet-latest", 
gjson.GetBytes(upstream.lastBody, "model").String()) + require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key")) + require.Empty(t, upstream.lastReq.Header.Get("authorization")) + require.Empty(t, upstream.lastReq.Header.Get("cookie")) + require.Equal(t, http.StatusOK, rec.Code) + require.JSONEq(t, upstreamRespBody, rec.Body.String()) + require.Empty(t, rec.Header().Get("Set-Cookie")) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotError(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + statusCode int + respBody string + wantPassthrough bool + }{ + { + name: "404 endpoint not found passes through as 404", + statusCode: http.StatusNotFound, + respBody: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"not_found_error"}}`, + wantPassthrough: true, + }, + { + name: "404 generic not found does not passthrough", + statusCode: http.StatusNotFound, + respBody: `{"error":{"message":"resource not found","type":"not_found_error"}}`, + wantPassthrough: false, + }, + { + name: "400 Invalid URL does not passthrough", + statusCode: http.StatusBadRequest, + respBody: `{"error":{"message":"Invalid URL (POST /v1/messages/count_tokens)","type":"invalid_request_error"}}`, + wantPassthrough: false, + }, + { + name: "400 model error does not passthrough", + statusCode: http.StatusBadRequest, + respBody: `{"error":{"message":"model not found: claude-unknown","type":"invalid_request_error"}}`, + wantPassthrough: false, + }, + { + name: "500 internal error does not passthrough", + statusCode: http.StatusInternalServerError, + respBody: `{"error":{"message":"internal error","type":"api_error"}}`, + wantPassthrough: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil) + + body := 
[]byte(`{"model":"claude-sonnet-4-5-20250929","messages":[{"role":"user","content":"hi"}]}`) + parsed := &ParsedRequest{Body: body, Model: "claude-sonnet-4-5-20250929"} + + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: tt.statusCode, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(tt.respBody)), + }, + } + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }, + httpUpstream: upstream, + rateLimitService: nil, + } + + account := &Account{ + ID: 200, + Name: "proxy-acc", + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-proxy", + "base_url": "https://proxy.example.com", + }, + Extra: map[string]any{"anthropic_passthrough": true}, + Status: StatusActive, + Schedulable: true, + } + + err := svc.ForwardCountTokens(context.Background(), c, account, parsed) + + if tt.wantPassthrough { + // 返回 nil(不记录为错误),HTTP 状态码 404 + Anthropic 错误体 + require.NoError(t, err) + require.Equal(t, http.StatusNotFound, rec.Code) + var errResp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &errResp)) + require.Equal(t, "error", errResp["type"]) + errObj, ok := errResp["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "not_found_error", errObj["type"]) + } else { + require.Error(t, err) + require.Equal(t, tt.statusCode, rec.Code) + } + }) + } +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_BuildRequestRejectsInvalidBaseURL(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{ + Enabled: false, + }, + }, + }, + } + account := &Account{ + Platform: 
PlatformAnthropic, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "k", + "base_url": "://invalid-url", + }, + } + + _, err := svc.buildUpstreamRequestAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "k") + require.Error(t, err) +} + +func TestGatewayService_AnthropicOAuth_NotAffectedByAPIKeyPassthroughToggle(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }, + } + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "anthropic_passthrough": true, + }, + } + + require.False(t, account.IsAnthropicAPIKeyPassthroughEnabled()) + + req, err := svc.buildUpstreamRequest(context.Background(), c, account, []byte(`{"model":"claude-3-7-sonnet-20250219"}`), "oauth-token", "oauth", "claude-3-7-sonnet-20250219", true, false) + require.NoError(t, err) + require.Equal(t, "Bearer oauth-token", req.Header.Get("authorization")) + require.Contains(t, req.Header.Get("anthropic-beta"), claude.BetaOAuth, "OAuth 链路仍应按原逻辑补齐 oauth beta") +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAfterClientDisconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Use a canceled context recorder to simulate client disconnect behavior. 
+ req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + ctx, cancel := context.WithCancel(req.Context()) + cancel() + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + }, + rateLimitService: &RateLimitService{}, + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"message_start","message":{"usage":{"input_tokens":11}}}`, + "", + `data: {"type":"message_delta","usage":{"output_tokens":5}}`, + "", + "data: [DONE]", + "", + }, "\n"))), + } + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "claude-3-7-sonnet-20250219") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 11, result.usage.InputTokens) + require.Equal(t, 5, result.usage.OutputTokens) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuccess(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + body := []byte(`{"model":"claude-3-5-sonnet-latest","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`) + upstreamJSON := `{"id":"msg_1","type":"message","usage":{"input_tokens":12,"output_tokens":7,"cache_creation":{"ephemeral_5m_input_tokens":2,"ephemeral_1h_input_tokens":3},"cached_tokens":4}}` + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "x-request-id": []string{"rid-nonstream"}, + }, + Body: 
io.NopCloser(strings.NewReader(upstreamJSON)), + }, + } + svc := &GatewayService{ + cfg: &config.Config{}, + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + } + + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", false, time.Now()) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 12, result.Usage.InputTokens) + require.Equal(t, 7, result.Usage.OutputTokens) + require.Equal(t, 5, result.Usage.CacheCreationInputTokens) + require.Equal(t, 4, result.Usage.CacheReadInputTokens) + require.Equal(t, upstreamJSON, rec.Body.String()) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_InvalidTokenType(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + account := &Account{ + ID: 202, + Name: "anthropic-oauth", + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "oauth-token", + }, + } + svc := &GatewayService{} + + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", false, time.Now()) + require.Nil(t, result) + require.Error(t, err) + require.Contains(t, err.Error(), "requires apikey token") +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_UpstreamRequestError(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + upstream := &anthropicHTTPUpstreamRecorder{ + err: errors.New("dial tcp timeout"), + } + svc := &GatewayService{ + cfg: &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{Enabled: false}, + }, + }, + httpUpstream: upstream, + } + account := 
newAnthropicAPIKeyAccountForTest() + + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", false, time.Now()) + require.Nil(t, result) + require.Error(t, err) + require.Contains(t, err.Error(), "upstream request failed") + require.Equal(t, http.StatusBadGateway, rec.Code) + rawBody, ok := c.Get(OpsUpstreamRequestBodyKey) + require.True(t, ok) + _, ok = rawBody.([]byte) + require.True(t, ok) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_EmptyResponseBody(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"x-request-id": []string{"rid-empty-body"}}, + Body: nil, + }, + } + svc := &GatewayService{ + cfg: &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{Enabled: false}, + }, + }, + httpUpstream: upstream, + } + + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", false, time.Now()) + require.Nil(t, result) + require.Error(t, err) + require.Contains(t, err.Error(), "empty response") +} + +func TestExtractAnthropicSSEDataLine(t *testing.T) { + t.Run("valid data line with spaces", func(t *testing.T) { + data, ok := extractAnthropicSSEDataLine("data: {\"type\":\"message_start\"}") + require.True(t, ok) + require.Equal(t, `{"type":"message_start"}`, data) + }) + + t.Run("non data line", func(t *testing.T) { + data, ok := extractAnthropicSSEDataLine("event: message_start") + require.False(t, ok) + require.Empty(t, data) + }) +} + +func TestGatewayService_ParseSSEUsagePassthrough_MessageStartFallbacks(t *testing.T) { + svc := &GatewayService{} + usage := &ClaudeUsage{} + data := 
`{"type":"message_start","message":{"usage":{"input_tokens":12,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cached_tokens":9,"cache_creation":{"ephemeral_5m_input_tokens":3,"ephemeral_1h_input_tokens":4}}}}` + + svc.parseSSEUsagePassthrough(data, usage) + + require.Equal(t, 12, usage.InputTokens) + require.Equal(t, 9, usage.CacheReadInputTokens, "应兼容 cached_tokens 字段") + require.Equal(t, 7, usage.CacheCreationInputTokens, "聚合字段为空时应从 5m/1h 明细回填") + require.Equal(t, 3, usage.CacheCreation5mTokens) + require.Equal(t, 4, usage.CacheCreation1hTokens) +} + +func TestGatewayService_ParseSSEUsagePassthrough_MessageDeltaSelectiveOverwrite(t *testing.T) { + svc := &GatewayService{} + usage := &ClaudeUsage{ + InputTokens: 10, + CacheCreation5mTokens: 2, + CacheCreation1hTokens: 6, + } + data := `{"type":"message_delta","usage":{"input_tokens":0,"output_tokens":5,"cache_creation_input_tokens":8,"cache_read_input_tokens":0,"cached_tokens":11,"cache_creation":{"ephemeral_5m_input_tokens":1,"ephemeral_1h_input_tokens":0}}}` + + svc.parseSSEUsagePassthrough(data, usage) + + require.Equal(t, 10, usage.InputTokens, "message_delta 中 0 值不应覆盖已有 input_tokens") + require.Equal(t, 5, usage.OutputTokens) + require.Equal(t, 8, usage.CacheCreationInputTokens) + require.Equal(t, 11, usage.CacheReadInputTokens, "cache_read_input_tokens 为空时应回退到 cached_tokens") + require.Equal(t, 1, usage.CacheCreation5mTokens) + require.Equal(t, 6, usage.CacheCreation1hTokens, "message_delta 中 0 值不应覆盖已有 1h 明细") +} + +func TestGatewayService_ParseSSEUsagePassthrough_NoopCases(t *testing.T) { + svc := &GatewayService{} + + usage := &ClaudeUsage{InputTokens: 3} + svc.parseSSEUsagePassthrough("", usage) + require.Equal(t, 3, usage.InputTokens) + + svc.parseSSEUsagePassthrough("[DONE]", usage) + require.Equal(t, 3, usage.InputTokens) + + svc.parseSSEUsagePassthrough("not-json", usage) + require.Equal(t, 3, usage.InputTokens) + + // nil usage 不应 panic + 
svc.parseSSEUsagePassthrough(`{"type":"message_start"}`, nil) +} + +func TestGatewayService_ParseSSEUsagePassthrough_FallbackFromUsageNode(t *testing.T) { + svc := &GatewayService{} + usage := &ClaudeUsage{} + data := `{"type":"content_block_delta","usage":{"cached_tokens":6,"cache_creation":{"ephemeral_5m_input_tokens":2,"ephemeral_1h_input_tokens":1}}}` + + svc.parseSSEUsagePassthrough(data, usage) + + require.Equal(t, 6, usage.CacheReadInputTokens) + require.Equal(t, 3, usage.CacheCreationInputTokens) +} + +func TestParseClaudeUsageFromResponseBody(t *testing.T) { + t.Run("empty or missing usage", func(t *testing.T) { + got := parseClaudeUsageFromResponseBody(nil) + require.NotNil(t, got) + require.Equal(t, 0, got.InputTokens) + + got = parseClaudeUsageFromResponseBody([]byte(`{"id":"x"}`)) + require.NotNil(t, got) + require.Equal(t, 0, got.OutputTokens) + }) + + t.Run("parse all usage fields and fallback", func(t *testing.T) { + body := []byte(`{"usage":{"input_tokens":21,"output_tokens":34,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cached_tokens":13,"cache_creation":{"ephemeral_5m_input_tokens":5,"ephemeral_1h_input_tokens":8}}}`) + got := parseClaudeUsageFromResponseBody(body) + require.Equal(t, 21, got.InputTokens) + require.Equal(t, 34, got.OutputTokens) + require.Equal(t, 13, got.CacheReadInputTokens, "cache_read_input_tokens 为空时应回退 cached_tokens") + require.Equal(t, 13, got.CacheCreationInputTokens, "聚合字段为空时应由 5m/1h 回填") + require.Equal(t, 5, got.CacheCreation5mTokens) + require.Equal(t, 8, got.CacheCreation1hTokens) + }) + + t.Run("keep explicit aggregate values", func(t *testing.T) { + body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"cache_creation_input_tokens":9,"cache_read_input_tokens":7,"cached_tokens":99,"cache_creation":{"ephemeral_5m_input_tokens":4,"ephemeral_1h_input_tokens":5}}}`) + got := parseClaudeUsageFromResponseBody(body) + require.Equal(t, 9, got.CacheCreationInputTokens, "已显式提供聚合字段时不应被明细覆盖") + 
require.Equal(t, 7, got.CacheReadInputTokens, "已显式提供 cache_read_input_tokens 时不应回退 cached_tokens") + }) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingErrTooLong(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: 32, + }, + }, + } + + // Scanner 初始缓冲为 64KB,构造更长单行触发 bufio.ErrTooLong。 + longLine := "data: " + strings.Repeat("x", 80*1024) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(longLine)), + } + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 2}, time.Now(), "claude-3-7-sonnet-20250219") + require.Error(t, err) + require.ErrorIs(t, err, bufio.ErrTooLong) + require.NotNil(t, result) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingDataIntervalTimeout(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 1, + MaxLineSize: defaultMaxLineSize, + }, + }, + rateLimitService: &RateLimitService{}, + } + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: pr, + } + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 5}, time.Now(), "claude-3-7-sonnet-20250219") + _ = pw.Close() + _ = pr.Close() + + require.Error(t, err) + require.Contains(t, err.Error(), "stream data interval timeout") + require.NotNil(t, result) + 
require.False(t, result.clientDisconnect) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingReadError(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + }, + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: &streamReadCloser{ + err: io.ErrUnexpectedEOF, + }, + } + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 6}, time.Now(), "claude-3-7-sonnet-20250219") + require.Error(t, err) + require.Contains(t, err.Error(), "stream read error") + require.NotNil(t, result) + require.False(t, result.clientDisconnect) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingTimeoutAfterClientDisconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + c.Writer = &failWriteResponseWriter{ResponseWriter: c.Writer} + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 1, + MaxLineSize: defaultMaxLineSize, + }, + }, + rateLimitService: &RateLimitService{}, + } + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: pr, + } + + done := make(chan struct{}) + go func() { + defer close(done) + _, _ = pw.Write([]byte(`data: {"type":"message_start","message":{"usage":{"input_tokens":9}}}` + "\n")) + // 保持上游连接静默,触发数据间隔超时分支。 + time.Sleep(1500 * time.Millisecond) + _ = pw.Close() + }() + + result, err := 
svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 7}, time.Now(), "claude-3-7-sonnet-20250219") + _ = pr.Close() + <-done + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.Equal(t, 9, result.usage.InputTokens) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + }, + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: &streamReadCloser{ + err: context.Canceled, + }, + } + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now(), "claude-3-7-sonnet-20250219") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingUpstreamReadErrorAfterClientDisconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + c.Writer = &failWriteResponseWriter{ResponseWriter: c.Writer} + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + }, + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: &streamReadCloser{ + payload: []byte(`data: {"type":"message_start","message":{"usage":{"input_tokens":8}}}` + "\n\n"), + err: io.ErrUnexpectedEOF, + }, + } + + result, err := 
svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now(), "claude-3-7-sonnet-20250219") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.Equal(t, 8, result.usage.InputTokens) +} diff --git a/backend/internal/service/gateway_beta_test.go b/backend/internal/service/gateway_beta_test.go index dd58c183..21a1faa4 100644 --- a/backend/internal/service/gateway_beta_test.go +++ b/backend/internal/service/gateway_beta_test.go @@ -3,6 +3,8 @@ package service import ( "testing" + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/stretchr/testify/require" ) @@ -21,3 +23,180 @@ func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) { ) require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got) } + +func TestStripBetaTokens(t *testing.T) { + tests := []struct { + name string + header string + tokens []string + want string + }{ + { + name: "single token in middle", + header: "oauth-2025-04-20,context-1m-2025-08-07,interleaved-thinking-2025-05-14", + tokens: []string{"context-1m-2025-08-07"}, + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "single token at start", + header: "context-1m-2025-08-07,oauth-2025-04-20,interleaved-thinking-2025-05-14", + tokens: []string{"context-1m-2025-08-07"}, + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "single token at end", + header: "oauth-2025-04-20,interleaved-thinking-2025-05-14,context-1m-2025-08-07", + tokens: []string{"context-1m-2025-08-07"}, + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "token not present", + header: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + tokens: []string{"context-1m-2025-08-07"}, + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "empty header", + header: "", + tokens: []string{"context-1m-2025-08-07"}, + want: "", + }, + { + name: "with spaces", + 
header: "oauth-2025-04-20, context-1m-2025-08-07 , interleaved-thinking-2025-05-14", + tokens: []string{"context-1m-2025-08-07"}, + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "only token", + header: "context-1m-2025-08-07", + tokens: []string{"context-1m-2025-08-07"}, + want: "", + }, + { + name: "nil tokens", + header: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + tokens: nil, + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "multiple tokens removed", + header: "oauth-2025-04-20,context-1m-2025-08-07,interleaved-thinking-2025-05-14,fast-mode-2026-02-01", + tokens: []string{"context-1m-2025-08-07", "fast-mode-2026-02-01"}, + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + { + name: "DroppedBetas removes both context-1m and fast-mode", + header: "oauth-2025-04-20,context-1m-2025-08-07,fast-mode-2026-02-01,interleaved-thinking-2025-05-14", + tokens: claude.DroppedBetas, + want: "oauth-2025-04-20,interleaved-thinking-2025-05-14", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := stripBetaTokens(tt.header, tt.tokens) + require.Equal(t, tt.want, got) + }) + } +} + +func TestMergeAnthropicBetaDropping_Context1M(t *testing.T) { + required := []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"} + incoming := "context-1m-2025-08-07,foo-beta,oauth-2025-04-20" + drop := map[string]struct{}{"context-1m-2025-08-07": {}} + + got := mergeAnthropicBetaDropping(required, incoming, drop) + require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo-beta", got) + require.NotContains(t, got, "context-1m-2025-08-07") +} + +func TestMergeAnthropicBetaDropping_DroppedBetas(t *testing.T) { + required := []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"} + incoming := "context-1m-2025-08-07,fast-mode-2026-02-01,foo-beta,oauth-2025-04-20" + drop := droppedBetaSet() + + got := mergeAnthropicBetaDropping(required, incoming, drop) + 
require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo-beta", got) + require.NotContains(t, got, "context-1m-2025-08-07") + require.NotContains(t, got, "fast-mode-2026-02-01") +} + +func TestDroppedBetaSet(t *testing.T) { + // Base set contains DroppedBetas + base := droppedBetaSet() + require.Contains(t, base, claude.BetaContext1M) + require.Contains(t, base, claude.BetaFastMode) + require.Len(t, base, len(claude.DroppedBetas)) + + // With extra tokens + extended := droppedBetaSet(claude.BetaClaudeCode) + require.Contains(t, extended, claude.BetaContext1M) + require.Contains(t, extended, claude.BetaFastMode) + require.Contains(t, extended, claude.BetaClaudeCode) + require.Len(t, extended, len(claude.DroppedBetas)+1) +} + +func TestBuildBetaTokenSet(t *testing.T) { + got := buildBetaTokenSet([]string{"foo", "", "bar", "foo"}) + require.Len(t, got, 2) + require.Contains(t, got, "foo") + require.Contains(t, got, "bar") + require.NotContains(t, got, "") + + empty := buildBetaTokenSet(nil) + require.Empty(t, empty) +} + +func TestStripBetaTokensWithSet_EmptyDropSet(t *testing.T) { + header := "oauth-2025-04-20,interleaved-thinking-2025-05-14" + got := stripBetaTokensWithSet(header, map[string]struct{}{}) + require.Equal(t, header, got) +} + +func TestIsCountTokensUnsupported404(t *testing.T) { + tests := []struct { + name string + statusCode int + body string + want bool + }{ + { + name: "exact endpoint not found", + statusCode: 404, + body: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"not_found_error"}}`, + want: true, + }, + { + name: "contains count_tokens and not found", + statusCode: 404, + body: `{"error":{"message":"count_tokens route not found","type":"not_found_error"}}`, + want: true, + }, + { + name: "generic 404", + statusCode: 404, + body: `{"error":{"message":"resource not found","type":"not_found_error"}}`, + want: false, + }, + { + name: "404 with empty error message", + statusCode: 404, + body: 
`{"error":{"message":"","type":"not_found_error"}}`, + want: false, + }, + { + name: "non-404 status", + statusCode: 400, + body: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"invalid_request_error"}}`, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isCountTokensUnsupported404(tt.statusCode, []byte(tt.body)) + require.Equal(t, tt.want, got) + }) + } +} diff --git a/backend/internal/service/gateway_group_isolation_test.go b/backend/internal/service/gateway_group_isolation_test.go new file mode 100644 index 00000000..00508f0e --- /dev/null +++ b/backend/internal/service/gateway_group_isolation_test.go @@ -0,0 +1,363 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +// ============================================================================ +// Part 1: isAccountInGroup 单元测试 +// ============================================================================ + +func TestIsAccountInGroup(t *testing.T) { + svc := &GatewayService{} + groupID100 := int64(100) + groupID200 := int64(200) + + tests := []struct { + name string + account *Account + groupID *int64 + expected bool + }{ + // groupID == nil(无分组 API Key) + { + "nil_groupID_ungrouped_account_nil_groups", + &Account{ID: 1, AccountGroups: nil}, + nil, true, + }, + { + "nil_groupID_ungrouped_account_empty_slice", + &Account{ID: 2, AccountGroups: []AccountGroup{}}, + nil, true, + }, + { + "nil_groupID_grouped_account_single", + &Account{ID: 3, AccountGroups: []AccountGroup{{GroupID: 100}}}, + nil, false, + }, + { + "nil_groupID_grouped_account_multiple", + &Account{ID: 4, AccountGroups: []AccountGroup{{GroupID: 100}, {GroupID: 200}}}, + nil, false, + }, + // groupID != nil(有分组 API Key) + { + "with_groupID_account_in_group", + &Account{ID: 5, AccountGroups: []AccountGroup{{GroupID: 100}}}, + &groupID100, true, + }, + { + 
"with_groupID_account_not_in_group", + &Account{ID: 6, AccountGroups: []AccountGroup{{GroupID: 200}}}, + &groupID100, false, + }, + { + "with_groupID_ungrouped_account", + &Account{ID: 7, AccountGroups: nil}, + &groupID100, false, + }, + { + "with_groupID_multi_group_account_match_one", + &Account{ID: 8, AccountGroups: []AccountGroup{{GroupID: 100}, {GroupID: 200}}}, + &groupID200, true, + }, + { + "with_groupID_multi_group_account_no_match", + &Account{ID: 9, AccountGroups: []AccountGroup{{GroupID: 300}, {GroupID: 400}}}, + &groupID100, false, + }, + // 防御性边界 + { + "nil_account_nil_groupID", + nil, + nil, false, + }, + { + "nil_account_with_groupID", + nil, + &groupID100, false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := svc.isAccountInGroup(tt.account, tt.groupID) + require.Equal(t, tt.expected, got, "isAccountInGroup 结果不符预期") + }) + } +} + +// ============================================================================ +// Part 2: 分组隔离端到端调度测试 +// ============================================================================ + +// groupAwareMockAccountRepo 嵌入 mockAccountRepoForPlatform,覆写分组隔离相关方法。 +// allAccounts 存储所有账号,分组查询方法按 AccountGroups 字段进行真实过滤。 +type groupAwareMockAccountRepo struct { + *mockAccountRepoForPlatform + allAccounts []Account +} + +// ListSchedulableUngroupedByPlatform 仅返回未分组账号(AccountGroups 为空) +func (m *groupAwareMockAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + var result []Account + for _, acc := range m.allAccounts { + if acc.Platform == platform && acc.IsSchedulable() && len(acc.AccountGroups) == 0 { + result = append(result, acc) + } + } + return result, nil +} + +// ListSchedulableUngroupedByPlatforms 仅返回未分组账号(多平台版本) +func (m *groupAwareMockAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { + platformSet := make(map[string]bool, len(platforms)) + for _, p := range platforms 
{ + platformSet[p] = true + } + var result []Account + for _, acc := range m.allAccounts { + if platformSet[acc.Platform] && acc.IsSchedulable() && len(acc.AccountGroups) == 0 { + result = append(result, acc) + } + } + return result, nil +} + +// ListSchedulableByGroupIDAndPlatform 返回属于指定分组的账号 +func (m *groupAwareMockAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) { + var result []Account + for _, acc := range m.allAccounts { + if acc.Platform == platform && acc.IsSchedulable() && accountBelongsToGroup(acc, groupID) { + result = append(result, acc) + } + } + return result, nil +} + +// ListSchedulableByGroupIDAndPlatforms 返回属于指定分组的账号(多平台版本) +func (m *groupAwareMockAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) { + platformSet := make(map[string]bool, len(platforms)) + for _, p := range platforms { + platformSet[p] = true + } + var result []Account + for _, acc := range m.allAccounts { + if platformSet[acc.Platform] && acc.IsSchedulable() && accountBelongsToGroup(acc, groupID) { + result = append(result, acc) + } + } + return result, nil +} + +// accountBelongsToGroup 检查账号是否属于指定分组 +func accountBelongsToGroup(acc Account, groupID int64) bool { + for _, ag := range acc.AccountGroups { + if ag.GroupID == groupID { + return true + } + } + return false +} + +// Verify interface implementation +var _ AccountRepository = (*groupAwareMockAccountRepo)(nil) + +// newGroupAwareMockRepo 创建分组感知的 mock repo +func newGroupAwareMockRepo(accounts []Account) *groupAwareMockAccountRepo { + byID := make(map[int64]*Account, len(accounts)) + for i := range accounts { + byID[accounts[i].ID] = &accounts[i] + } + return &groupAwareMockAccountRepo{ + mockAccountRepoForPlatform: &mockAccountRepoForPlatform{ + accounts: accounts, + accountsByID: byID, + }, + allAccounts: accounts, + } +} + +func 
TestGroupIsolation_UngroupedKey_ShouldNotScheduleGroupedAccounts(t *testing.T) { + // 场景:无分组 API Key(groupID=nil),池中只有已分组账号 → 应返回错误 + ctx := context.Background() + + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 100}}}, + {ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 200}}}, + } + repo := newGroupAwareMockRepo(accounts) + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI) + require.Error(t, err, "无分组 Key 不应调度到已分组账号") + require.Nil(t, acc) +} + +func TestGroupIsolation_GroupedKey_ShouldNotScheduleUngroupedAccounts(t *testing.T) { + // 场景:有分组 API Key(groupID=100),池中只有未分组账号 → 应返回错误 + ctx := context.Background() + groupID := int64(100) + + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true, + AccountGroups: nil}, + {ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{}}, + } + repo := newGroupAwareMockRepo(accounts) + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "", "", nil, PlatformOpenAI) + require.Error(t, err, "有分组 Key 不应调度到未分组账号") + require.Nil(t, acc) +} + +func TestGroupIsolation_UngroupedKey_ShouldOnlyScheduleUngroupedAccounts(t *testing.T) { + // 场景:无分组 API Key(groupID=nil),池中有未分组和已分组账号 → 应只选中未分组的 + ctx := context.Background() + + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 100}}}, // 已分组,不应被选中 + {ID: 2, Platform: 
PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true, + AccountGroups: nil}, // 未分组,应被选中 + {ID: 3, Platform: PlatformOpenAI, Priority: 3, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 200}}}, // 已分组,不应被选中 + } + repo := newGroupAwareMockRepo(accounts) + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI) + require.NoError(t, err, "应成功调度未分组账号") + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID, "应选中未分组的账号 ID=2") +} + +func TestGroupIsolation_GroupedKey_ShouldOnlyScheduleMatchingGroupAccounts(t *testing.T) { + // 场景:有分组 API Key(groupID=100),池中有未分组和多个分组账号 → 应只选中分组 100 内的 + ctx := context.Background() + groupID := int64(100) + + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true, + AccountGroups: nil}, // 未分组,不应被选中 + {ID: 2, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 200}}}, // 属于分组 200,不应被选中 + {ID: 3, Platform: PlatformOpenAI, Priority: 3, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 100}}}, // 属于分组 100,应被选中 + } + repo := newGroupAwareMockRepo(accounts) + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, &groupID, "", "", nil, PlatformOpenAI) + require.NoError(t, err, "应成功调度分组内账号") + require.NotNil(t, acc) + require.Equal(t, int64(3), acc.ID, "应选中分组 100 内的账号 ID=3") +} + +// ============================================================================ +// Part 3: SimpleMode 旁路测试 +// ============================================================================ + +func TestGroupIsolation_SimpleMode_SkipsGroupIsolation(t *testing.T) 
{ + // SimpleMode 应跳过分组隔离,使用 ListSchedulableByPlatform 返回所有账号。 + // 测试非 useMixed 路径(platform=openai,不会触发 mixed 调度逻辑)。 + ctx := context.Background() + + // 混合未分组和已分组账号,SimpleMode 下应全部可调度 + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Priority: 2, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 100}}}, // 已分组 + {ID: 2, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true, + AccountGroups: nil}, // 未分组 + } + + // 使用基础 mock(ListSchedulableByPlatform 返回所有匹配平台的账号,不做分组过滤) + byID := make(map[int64]*Account, len(accounts)) + for i := range accounts { + byID[accounts[i].ID] = &accounts[i] + } + repo := &mockAccountRepoForPlatform{ + accounts: accounts, + accountsByID: byID, + } + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: &config.Config{RunMode: config.RunModeSimple}, + } + + // groupID=nil 时,SimpleMode 应使用 ListSchedulableByPlatform(不过滤分组) + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI) + require.NoError(t, err, "SimpleMode 应跳过分组隔离直接返回账号") + require.NotNil(t, acc) + // 应选择优先级最高的账号(Priority=1, ID=2),即使它未分组 + require.Equal(t, int64(2), acc.ID, "SimpleMode 应按优先级选择,不考虑分组") +} + +func TestGroupIsolation_SimpleMode_GroupedAccountAlsoSchedulable(t *testing.T) { + // SimpleMode + groupID=nil 时,已分组账号也应该可被调度 + ctx := context.Background() + + // 只有已分组账号,在 standard 模式下 groupID=nil 会报错,但 simple 模式应正常 + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Priority: 1, Status: StatusActive, Schedulable: true, + AccountGroups: []AccountGroup{{GroupID: 100}}}, + } + + byID := make(map[int64]*Account, len(accounts)) + for i := range accounts { + byID[accounts[i].ID] = &accounts[i] + } + repo := &mockAccountRepoForPlatform{ + accounts: accounts, + accountsByID: byID, + } + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: 
&config.Config{RunMode: config.RunModeSimple}, + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "", nil, PlatformOpenAI) + require.NoError(t, err, "SimpleMode 下已分组账号也应可调度") + require.NotNil(t, acc) + require.Equal(t, int64(1), acc.ID, "SimpleMode 应能调度已分组账号") +} diff --git a/backend/internal/service/gateway_hotpath_optimization_test.go b/backend/internal/service/gateway_hotpath_optimization_test.go new file mode 100644 index 00000000..161c4ba4 --- /dev/null +++ b/backend/internal/service/gateway_hotpath_optimization_test.go @@ -0,0 +1,786 @@ +package service + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + gocache "github.com/patrickmn/go-cache" + "github.com/stretchr/testify/require" +) + +type userGroupRateRepoHotpathStub struct { + UserGroupRateRepository + + rate *float64 + err error + wait <-chan struct{} + calls atomic.Int64 +} + +func (s *userGroupRateRepoHotpathStub) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) { + s.calls.Add(1) + if s.wait != nil { + <-s.wait + } + if s.err != nil { + return nil, s.err + } + return s.rate, nil +} + +type usageLogWindowBatchRepoStub struct { + UsageLogRepository + + batchResult map[int64]*usagestats.AccountStats + batchErr error + batchCalls atomic.Int64 + + singleResult map[int64]*usagestats.AccountStats + singleErr error + singleCalls atomic.Int64 +} + +func (s *usageLogWindowBatchRepoStub) GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) { + s.batchCalls.Add(1) + if s.batchErr != nil { + return nil, s.batchErr + } + out := make(map[int64]*usagestats.AccountStats, len(accountIDs)) + for _, id := range accountIDs { + if stats, ok := s.batchResult[id]; ok { + out[id] = stats + } + } + return out, 
nil +} + +func (s *usageLogWindowBatchRepoStub) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) { + s.singleCalls.Add(1) + if s.singleErr != nil { + return nil, s.singleErr + } + if stats, ok := s.singleResult[accountID]; ok { + return stats, nil + } + return &usagestats.AccountStats{}, nil +} + +type sessionLimitCacheHotpathStub struct { + SessionLimitCache + + batchData map[int64]float64 + batchErr error + + setData map[int64]float64 + setErr error +} + +func (s *sessionLimitCacheHotpathStub) GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error) { + if s.batchErr != nil { + return nil, s.batchErr + } + out := make(map[int64]float64, len(accountIDs)) + for _, id := range accountIDs { + if v, ok := s.batchData[id]; ok { + out[id] = v + } + } + return out, nil +} + +func (s *sessionLimitCacheHotpathStub) SetWindowCost(ctx context.Context, accountID int64, cost float64) error { + if s.setErr != nil { + return s.setErr + } + if s.setData == nil { + s.setData = make(map[int64]float64) + } + s.setData[accountID] = cost + return nil +} + +type modelsListAccountRepoStub struct { + AccountRepository + + byGroup map[int64][]Account + all []Account + err error + + listByGroupCalls atomic.Int64 + listAllCalls atomic.Int64 +} + +type stickyGatewayCacheHotpathStub struct { + GatewayCache + + stickyID int64 + getCalls atomic.Int64 +} + +func (s *stickyGatewayCacheHotpathStub) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) { + s.getCalls.Add(1) + if s.stickyID > 0 { + return s.stickyID, nil + } + return 0, errors.New("not found") +} + +func (s *stickyGatewayCacheHotpathStub) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error { + return nil +} + +func (s *stickyGatewayCacheHotpathStub) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl 
time.Duration) error { + return nil +} + +func (s *stickyGatewayCacheHotpathStub) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error { + return nil +} + +func (s *modelsListAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) { + s.listByGroupCalls.Add(1) + if s.err != nil { + return nil, s.err + } + accounts, ok := s.byGroup[groupID] + if !ok { + return nil, nil + } + out := make([]Account, len(accounts)) + copy(out, accounts) + return out, nil +} + +func (s *modelsListAccountRepoStub) ListSchedulable(ctx context.Context) ([]Account, error) { + s.listAllCalls.Add(1) + if s.err != nil { + return nil, s.err + } + out := make([]Account, len(s.all)) + copy(out, s.all) + return out, nil +} + +func resetGatewayHotpathStatsForTest() { + windowCostPrefetchCacheHitTotal.Store(0) + windowCostPrefetchCacheMissTotal.Store(0) + windowCostPrefetchBatchSQLTotal.Store(0) + windowCostPrefetchFallbackTotal.Store(0) + windowCostPrefetchErrorTotal.Store(0) + + userGroupRateCacheHitTotal.Store(0) + userGroupRateCacheMissTotal.Store(0) + userGroupRateCacheLoadTotal.Store(0) + userGroupRateCacheSFSharedTotal.Store(0) + userGroupRateCacheFallbackTotal.Store(0) + + modelsListCacheHitTotal.Store(0) + modelsListCacheMissTotal.Store(0) + modelsListCacheStoreTotal.Store(0) +} + +func TestGetUserGroupRateMultiplier_UsesCacheAndSingleflight(t *testing.T) { + resetGatewayHotpathStatsForTest() + + rate := 1.7 + unblock := make(chan struct{}) + repo := &userGroupRateRepoHotpathStub{ + rate: &rate, + wait: unblock, + } + svc := &GatewayService{ + userGroupRateRepo: repo, + userGroupRateCache: gocache.New(time.Minute, time.Minute), + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + UserGroupRateCacheTTLSeconds: 30, + }, + }, + } + + const concurrent = 12 + results := make([]float64, concurrent) + start := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(concurrent) + for i := 0; i < concurrent; i++ { + go 
func(idx int) { + defer wg.Done() + <-start + results[idx] = svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.2) + }(i) + } + + close(start) + time.Sleep(20 * time.Millisecond) + close(unblock) + wg.Wait() + + for _, got := range results { + require.Equal(t, rate, got) + } + require.Equal(t, int64(1), repo.calls.Load()) + + // 再次读取应命中缓存,不再回源。 + got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.2) + require.Equal(t, rate, got) + require.Equal(t, int64(1), repo.calls.Load()) + + hit, miss, load, sfShared, fallback := GatewayUserGroupRateCacheStats() + require.GreaterOrEqual(t, hit, int64(1)) + require.Equal(t, int64(12), miss) + require.Equal(t, int64(1), load) + require.GreaterOrEqual(t, sfShared, int64(1)) + require.Equal(t, int64(0), fallback) +} + +func TestGetUserGroupRateMultiplier_FallbackOnRepoError(t *testing.T) { + resetGatewayHotpathStatsForTest() + + repo := &userGroupRateRepoHotpathStub{ + err: errors.New("db down"), + } + svc := &GatewayService{ + userGroupRateRepo: repo, + userGroupRateCache: gocache.New(time.Minute, time.Minute), + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + UserGroupRateCacheTTLSeconds: 30, + }, + }, + } + + got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.25) + require.Equal(t, 1.25, got) + require.Equal(t, int64(1), repo.calls.Load()) + + _, _, _, _, fallback := GatewayUserGroupRateCacheStats() + require.Equal(t, int64(1), fallback) +} + +func TestGetUserGroupRateMultiplier_CacheHitAndNilRepo(t *testing.T) { + resetGatewayHotpathStatsForTest() + + repo := &userGroupRateRepoHotpathStub{ + err: errors.New("should not be called"), + } + svc := &GatewayService{ + userGroupRateRepo: repo, + userGroupRateCache: gocache.New(time.Minute, time.Minute), + } + key := "101:202" + svc.userGroupRateCache.Set(key, 2.3, time.Minute) + + got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.1) + require.Equal(t, 2.3, got) + + hit, miss, load, _, fallback := 
GatewayUserGroupRateCacheStats() + require.Equal(t, int64(1), hit) + require.Equal(t, int64(0), miss) + require.Equal(t, int64(0), load) + require.Equal(t, int64(0), fallback) + require.Equal(t, int64(0), repo.calls.Load()) + + // 无 repo 时直接返回分组默认倍率 + svc2 := &GatewayService{ + userGroupRateCache: gocache.New(time.Minute, time.Minute), + } + svc2.userGroupRateCache.Set(key, 1.9, time.Minute) + require.Equal(t, 1.9, svc2.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.4)) + require.Equal(t, 1.4, svc2.getUserGroupRateMultiplier(context.Background(), 0, 202, 1.4)) + svc2.userGroupRateCache.Delete(key) + require.Equal(t, 1.4, svc2.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.4)) +} + +func TestWithWindowCostPrefetch_BatchReadAndContextReuse(t *testing.T) { + resetGatewayHotpathStatsForTest() + + windowStart := time.Now().Add(-30 * time.Minute).Truncate(time.Hour) + windowEnd := windowStart.Add(5 * time.Hour) + accounts := []Account{ + { + ID: 1, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{"window_cost_limit": 100.0}, + SessionWindowStart: &windowStart, + SessionWindowEnd: &windowEnd, + }, + { + ID: 2, + Platform: PlatformAnthropic, + Type: AccountTypeSetupToken, + Extra: map[string]any{"window_cost_limit": 100.0}, + SessionWindowStart: &windowStart, + SessionWindowEnd: &windowEnd, + }, + { + ID: 3, + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{"window_cost_limit": 100.0}, + }, + } + + cache := &sessionLimitCacheHotpathStub{ + batchData: map[int64]float64{ + 1: 11.0, + }, + } + repo := &usageLogWindowBatchRepoStub{ + batchResult: map[int64]*usagestats.AccountStats{ + 2: {StandardCost: 22.0}, + }, + } + svc := &GatewayService{ + sessionLimitCache: cache, + usageLogRepo: repo, + } + + outCtx := svc.withWindowCostPrefetch(context.Background(), accounts) + require.NotNil(t, outCtx) + + cost1, ok1 := windowCostFromPrefetchContext(outCtx, 1) + require.True(t, ok1) + 
require.Equal(t, 11.0, cost1) + + cost2, ok2 := windowCostFromPrefetchContext(outCtx, 2) + require.True(t, ok2) + require.Equal(t, 22.0, cost2) + + _, ok3 := windowCostFromPrefetchContext(outCtx, 3) + require.False(t, ok3) + + require.Equal(t, int64(1), repo.batchCalls.Load()) + require.Equal(t, 22.0, cache.setData[2]) + + hit, miss, batchSQL, fallback, errCount := GatewayWindowCostPrefetchStats() + require.Equal(t, int64(1), hit) + require.Equal(t, int64(1), miss) + require.Equal(t, int64(1), batchSQL) + require.Equal(t, int64(0), fallback) + require.Equal(t, int64(0), errCount) +} + +func TestWithWindowCostPrefetch_AllHitNoSQL(t *testing.T) { + resetGatewayHotpathStatsForTest() + + windowStart := time.Now().Add(-30 * time.Minute).Truncate(time.Hour) + windowEnd := windowStart.Add(5 * time.Hour) + accounts := []Account{ + { + ID: 1, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{"window_cost_limit": 100.0}, + SessionWindowStart: &windowStart, + SessionWindowEnd: &windowEnd, + }, + { + ID: 2, + Platform: PlatformAnthropic, + Type: AccountTypeSetupToken, + Extra: map[string]any{"window_cost_limit": 100.0}, + SessionWindowStart: &windowStart, + SessionWindowEnd: &windowEnd, + }, + } + + cache := &sessionLimitCacheHotpathStub{ + batchData: map[int64]float64{ + 1: 11.0, + 2: 22.0, + }, + } + repo := &usageLogWindowBatchRepoStub{} + svc := &GatewayService{ + sessionLimitCache: cache, + usageLogRepo: repo, + } + + outCtx := svc.withWindowCostPrefetch(context.Background(), accounts) + cost1, ok1 := windowCostFromPrefetchContext(outCtx, 1) + cost2, ok2 := windowCostFromPrefetchContext(outCtx, 2) + require.True(t, ok1) + require.True(t, ok2) + require.Equal(t, 11.0, cost1) + require.Equal(t, 22.0, cost2) + require.Equal(t, int64(0), repo.batchCalls.Load()) + require.Equal(t, int64(0), repo.singleCalls.Load()) + + hit, miss, batchSQL, fallback, errCount := GatewayWindowCostPrefetchStats() + require.Equal(t, int64(2), hit) + require.Equal(t, 
int64(0), miss) + require.Equal(t, int64(0), batchSQL) + require.Equal(t, int64(0), fallback) + require.Equal(t, int64(0), errCount) +} + +func TestWithWindowCostPrefetch_BatchErrorFallbackSingleQuery(t *testing.T) { + resetGatewayHotpathStatsForTest() + + windowStart := time.Now().Add(-30 * time.Minute).Truncate(time.Hour) + windowEnd := windowStart.Add(5 * time.Hour) + accounts := []Account{ + { + ID: 2, + Platform: PlatformAnthropic, + Type: AccountTypeSetupToken, + Extra: map[string]any{"window_cost_limit": 100.0}, + SessionWindowStart: &windowStart, + SessionWindowEnd: &windowEnd, + }, + } + + cache := &sessionLimitCacheHotpathStub{} + repo := &usageLogWindowBatchRepoStub{ + batchErr: errors.New("batch failed"), + singleResult: map[int64]*usagestats.AccountStats{ + 2: {StandardCost: 33.0}, + }, + } + svc := &GatewayService{ + sessionLimitCache: cache, + usageLogRepo: repo, + } + + outCtx := svc.withWindowCostPrefetch(context.Background(), accounts) + cost, ok := windowCostFromPrefetchContext(outCtx, 2) + require.True(t, ok) + require.Equal(t, 33.0, cost) + require.Equal(t, int64(1), repo.batchCalls.Load()) + require.Equal(t, int64(1), repo.singleCalls.Load()) + + _, _, _, fallback, errCount := GatewayWindowCostPrefetchStats() + require.Equal(t, int64(1), fallback) + require.Equal(t, int64(1), errCount) +} + +func TestGetAvailableModels_UsesShortCacheAndSupportsInvalidation(t *testing.T) { + resetGatewayHotpathStatsForTest() + + groupID := int64(9) + repo := &modelsListAccountRepoStub{ + byGroup: map[int64][]Account{ + groupID: { + { + ID: 1, + Platform: PlatformAnthropic, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-3-5-sonnet": "claude-3-5-sonnet", + "claude-3-5-haiku": "claude-3-5-haiku", + }, + }, + }, + { + ID: 2, + Platform: PlatformGemini, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-2.5-pro": "gemini-2.5-pro", + }, + }, + }, + }, + }, + } + + svc := &GatewayService{ + accountRepo: repo, + 
modelsListCache: gocache.New(time.Minute, time.Minute), + modelsListCacheTTL: time.Minute, + } + + models1 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic) + require.Equal(t, []string{"claude-3-5-haiku", "claude-3-5-sonnet"}, models1) + require.Equal(t, int64(1), repo.listByGroupCalls.Load()) + + // TTL 内再次请求应命中缓存,不回源。 + models2 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic) + require.Equal(t, models1, models2) + require.Equal(t, int64(1), repo.listByGroupCalls.Load()) + + // 更新仓储数据,但缓存未失效前应继续返回旧值。 + repo.byGroup[groupID] = []Account{ + { + ID: 3, + Platform: PlatformAnthropic, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-3-7-sonnet": "claude-3-7-sonnet", + }, + }, + }, + } + models3 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic) + require.Equal(t, []string{"claude-3-5-haiku", "claude-3-5-sonnet"}, models3) + require.Equal(t, int64(1), repo.listByGroupCalls.Load()) + + svc.InvalidateAvailableModelsCache(&groupID, PlatformAnthropic) + models4 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic) + require.Equal(t, []string{"claude-3-7-sonnet"}, models4) + require.Equal(t, int64(2), repo.listByGroupCalls.Load()) + + hit, miss, store := GatewayModelsListCacheStats() + require.Equal(t, int64(2), hit) + require.Equal(t, int64(2), miss) + require.Equal(t, int64(2), store) +} + +func TestGetAvailableModels_ErrorAndGlobalListBranches(t *testing.T) { + resetGatewayHotpathStatsForTest() + + errRepo := &modelsListAccountRepoStub{ + err: errors.New("db error"), + } + svcErr := &GatewayService{ + accountRepo: errRepo, + modelsListCache: gocache.New(time.Minute, time.Minute), + modelsListCacheTTL: time.Minute, + } + require.Nil(t, svcErr.GetAvailableModels(context.Background(), nil, "")) + + okRepo := &modelsListAccountRepoStub{ + all: []Account{ + { + ID: 1, + Platform: PlatformAnthropic, + Credentials: map[string]any{ + "model_mapping": 
map[string]any{ + "claude-3-5-sonnet": "claude-3-5-sonnet", + }, + }, + }, + { + ID: 2, + Platform: PlatformGemini, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-2.5-pro": "gemini-2.5-pro", + }, + }, + }, + }, + } + svcOK := &GatewayService{ + accountRepo: okRepo, + modelsListCache: gocache.New(time.Minute, time.Minute), + modelsListCacheTTL: time.Minute, + } + models := svcOK.GetAvailableModels(context.Background(), nil, "") + require.Equal(t, []string{"claude-3-5-sonnet", "gemini-2.5-pro"}, models) + require.Equal(t, int64(1), okRepo.listAllCalls.Load()) +} + +func TestGatewayHotpathHelpers_CacheTTLAndStickyContext(t *testing.T) { + t.Run("resolve_user_group_rate_cache_ttl", func(t *testing.T) { + require.Equal(t, defaultUserGroupRateCacheTTL, resolveUserGroupRateCacheTTL(nil)) + + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + UserGroupRateCacheTTLSeconds: 45, + }, + } + require.Equal(t, 45*time.Second, resolveUserGroupRateCacheTTL(cfg)) + }) + + t.Run("resolve_models_list_cache_ttl", func(t *testing.T) { + require.Equal(t, defaultModelsListCacheTTL, resolveModelsListCacheTTL(nil)) + + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + ModelsListCacheTTLSeconds: 20, + }, + } + require.Equal(t, 20*time.Second, resolveModelsListCacheTTL(cfg)) + }) + + t.Run("prefetched_sticky_account_id_from_context", func(t *testing.T) { + require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.TODO(), nil)) + require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.Background(), nil)) + + ctx := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, int64(123)) + ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0)) + require.Equal(t, int64(123), prefetchedStickyAccountIDFromContext(ctx, nil)) + + groupID := int64(9) + ctx2 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, 456) + ctx2 = context.WithValue(ctx2, 
ctxkey.PrefetchedStickyGroupID, groupID) + require.Equal(t, int64(456), prefetchedStickyAccountIDFromContext(ctx2, &groupID)) + + ctx3 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, "invalid") + ctx3 = context.WithValue(ctx3, ctxkey.PrefetchedStickyGroupID, groupID) + require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(ctx3, &groupID)) + + ctx4 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, int64(789)) + ctx4 = context.WithValue(ctx4, ctxkey.PrefetchedStickyGroupID, int64(10)) + require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(ctx4, &groupID)) + }) + + t.Run("window_cost_from_prefetch_context", func(t *testing.T) { + require.Equal(t, false, func() bool { + _, ok := windowCostFromPrefetchContext(context.TODO(), 0) + return ok + }()) + require.Equal(t, false, func() bool { + _, ok := windowCostFromPrefetchContext(context.Background(), 1) + return ok + }()) + + ctx := context.WithValue(context.Background(), windowCostPrefetchContextKey, map[int64]float64{ + 9: 12.34, + }) + cost, ok := windowCostFromPrefetchContext(ctx, 9) + require.True(t, ok) + require.Equal(t, 12.34, cost) + }) +} + +func TestInvalidateAvailableModelsCache_ByDimensions(t *testing.T) { + svc := &GatewayService{ + modelsListCache: gocache.New(time.Minute, time.Minute), + } + group9 := int64(9) + group10 := int64(10) + svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformAnthropic), []string{"a"}, time.Minute) + svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformGemini), []string{"b"}, time.Minute) + svc.modelsListCache.Set(modelsListCacheKey(&group10, PlatformAnthropic), []string{"c"}, time.Minute) + svc.modelsListCache.Set("invalid-key", []string{"d"}, time.Minute) + + t.Run("invalidate_group_and_platform", func(t *testing.T) { + svc.InvalidateAvailableModelsCache(&group9, PlatformAnthropic) + _, found := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformAnthropic)) + require.False(t, 
found) + _, stillFound := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformGemini)) + require.True(t, stillFound) + }) + + t.Run("invalidate_group_only", func(t *testing.T) { + svc.InvalidateAvailableModelsCache(&group9, "") + _, foundA := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformAnthropic)) + _, foundB := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformGemini)) + require.False(t, foundA) + require.False(t, foundB) + _, foundOtherGroup := svc.modelsListCache.Get(modelsListCacheKey(&group10, PlatformAnthropic)) + require.True(t, foundOtherGroup) + }) + + t.Run("invalidate_platform_only", func(t *testing.T) { + // 重建数据后仅按 platform 失效 + svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformAnthropic), []string{"a"}, time.Minute) + svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformGemini), []string{"b"}, time.Minute) + svc.modelsListCache.Set(modelsListCacheKey(&group10, PlatformAnthropic), []string{"c"}, time.Minute) + + svc.InvalidateAvailableModelsCache(nil, PlatformAnthropic) + _, found9Anthropic := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformAnthropic)) + _, found10Anthropic := svc.modelsListCache.Get(modelsListCacheKey(&group10, PlatformAnthropic)) + _, found9Gemini := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformGemini)) + require.False(t, found9Anthropic) + require.False(t, found10Anthropic) + require.True(t, found9Gemini) + }) +} + +func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) { + now := time.Now().Add(-time.Minute) + account := Account{ + ID: 88, + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 4, + Priority: 1, + LastUsedAt: &now, + } + + repo := stubOpenAIAccountRepo{accounts: []Account{account}} + concurrency := NewConcurrencyService(stubConcurrencyCache{}) + + cfg := &config.Config{ + RunMode: config.RunModeStandard, + Gateway: config.GatewayConfig{ + Scheduling: 
config.GatewaySchedulingConfig{ + LoadBatchEnabled: true, + StickySessionMaxWaiting: 3, + StickySessionWaitTimeout: time.Second, + FallbackWaitTimeout: time.Second, + FallbackMaxWaiting: 10, + }, + }, + } + + baseCtx := context.WithValue(context.Background(), ctxkey.ForcePlatform, PlatformAnthropic) + + t.Run("without_prefetch_reads_cache_once", func(t *testing.T) { + cache := &stickyGatewayCacheHotpathStub{stickyID: account.ID} + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: cfg, + concurrencyService: concurrency, + userGroupRateCache: gocache.New(time.Minute, time.Minute), + modelsListCache: gocache.New(time.Minute, time.Minute), + modelsListCacheTTL: time.Minute, + } + + result, err := svc.SelectAccountWithLoadAwareness(baseCtx, nil, "sess-hash", "", nil, "") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, account.ID, result.Account.ID) + require.Equal(t, int64(1), cache.getCalls.Load()) + }) + + t.Run("with_prefetch_skips_cache_read", func(t *testing.T) { + cache := &stickyGatewayCacheHotpathStub{stickyID: account.ID} + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: cfg, + concurrencyService: concurrency, + userGroupRateCache: gocache.New(time.Minute, time.Minute), + modelsListCache: gocache.New(time.Minute, time.Minute), + modelsListCacheTTL: time.Minute, + } + + ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, account.ID) + ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0)) + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, account.ID, result.Account.ID) + require.Equal(t, int64(0), cache.getCalls.Load()) + }) + + t.Run("with_prefetch_group_mismatch_reads_cache", func(t *testing.T) { + cache := &stickyGatewayCacheHotpathStub{stickyID: account.ID} + svc := &GatewayService{ + 
accountRepo: repo, + cache: cache, + cfg: cfg, + concurrencyService: concurrency, + userGroupRateCache: gocache.New(time.Minute, time.Minute), + modelsListCache: gocache.New(time.Minute, time.Minute), + modelsListCacheTTL: time.Minute, + } + + ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, int64(999)) + ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(77)) + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, account.ID, result.Account.ID) + require.Equal(t, int64(1), cache.getCalls.Load()) + }) +} diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 09fda60e..1cb3c61e 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -77,6 +77,11 @@ func (m *mockAccountRepoForPlatform) Create(ctx context.Context, account *Accoun func (m *mockAccountRepoForPlatform) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) { return nil, nil } + +func (m *mockAccountRepoForPlatform) FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) { + return nil, nil +} + func (m *mockAccountRepoForPlatform) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) { return nil, nil } @@ -142,6 +147,12 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByPlatforms(ctx context.Cont func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) { return m.ListSchedulableByPlatforms(ctx, platforms) } +func (m *mockAccountRepoForPlatform) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + return m.ListSchedulableByPlatform(ctx, platform) +} +func (m *mockAccountRepoForPlatform) 
ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { + return m.ListSchedulableByPlatforms(ctx, platforms) +} func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return nil } @@ -890,6 +901,55 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiPreferOAuth(t *t require.Equal(t, int64(2), acc.ID) } +func TestGatewayService_SelectAccountForModelWithPlatform_GeminiAPIKeyModelMappingFilter(t *testing.T) { + ctx := context.Background() + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + { + ID: 1, + Platform: PlatformGemini, + Type: AccountTypeAPIKey, + Priority: 1, + Status: StatusActive, + Schedulable: true, + Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"}}, + }, + { + ID: 2, + Platform: PlatformGemini, + Type: AccountTypeAPIKey, + Priority: 2, + Status: StatusActive, + Schedulable: true, + Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-flash": "gemini-2.5-flash"}}, + }, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForPlatform{} + + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: testConfig(), + } + + acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-flash", nil, PlatformGemini) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID, "应过滤不支持请求模型的 APIKey 账号") + + acc, err = svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-3-pro-preview", nil, PlatformGemini) + require.Error(t, err) + require.Nil(t, acc) + require.Contains(t, err.Error(), "supporting model") +} + func TestGatewayService_SelectAccountForModelWithPlatform_StickyInGroup(t *testing.T) { ctx := context.Background() groupID := int64(50) @@ -1065,6 +1125,36 @@ func 
TestGatewayService_isModelSupportedByAccount(t *testing.T) { model: "claude-3-5-sonnet-20241022", expected: true, }, + { + name: "Gemini平台-无映射配置-支持所有模型", + account: &Account{Platform: PlatformGemini, Type: AccountTypeAPIKey}, + model: "gemini-2.5-flash", + expected: true, + }, + { + name: "Gemini平台-有映射配置-只支持配置的模型", + account: &Account{ + Platform: PlatformGemini, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"}, + }, + }, + model: "gemini-2.5-flash", + expected: false, + }, + { + name: "Gemini平台-有映射配置-支持配置的模型", + account: &Account{ + Platform: PlatformGemini, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "model_mapping": map[string]any{"gemini-2.5-pro": "gemini-2.5-pro"}, + }, + }, + model: "gemini-2.5-pro", + expected: true, + }, } for _, tt := range tests { @@ -1808,6 +1898,14 @@ func (m *mockConcurrencyCache) GetAccountConcurrency(ctx context.Context, accoun return 0, nil } +func (m *mockConcurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { + result := make(map[int64]int, len(accountIDs)) + for _, accountID := range accountIDs { + result[accountID] = 0 + } + return result, nil +} + func (m *mockConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { return true, nil } diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go index 743dd738..b546fe85 100644 --- a/backend/internal/service/gateway_request.go +++ b/backend/internal/service/gateway_request.go @@ -5,9 +5,28 @@ import ( "encoding/json" "fmt" "math" + "unsafe" "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +var ( + // 这些字节模式用于 fast-path 判断,避免每次 []byte("...") 产生临时分配。 + patternTypeThinking = []byte(`"type":"thinking"`) + 
patternTypeThinkingSpaced = []byte(`"type": "thinking"`) + patternTypeRedactedThinking = []byte(`"type":"redacted_thinking"`) + patternTypeRedactedSpaced = []byte(`"type": "redacted_thinking"`) + + patternThinkingField = []byte(`"thinking":`) + patternThinkingFieldSpaced = []byte(`"thinking" :`) + + patternEmptyContent = []byte(`"content":[]`) + patternEmptyContentSpaced = []byte(`"content": []`) + patternEmptyContentSp1 = []byte(`"content" : []`) + patternEmptyContentSp2 = []byte(`"content" :[]`) ) // SessionContext 粘性会话上下文,用于区分不同来源的请求。 @@ -42,119 +61,137 @@ type ParsedRequest struct { ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名) MaxTokens int // max_tokens 值(用于探测请求拦截) SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变) + + // OnUpstreamAccepted 上游接受请求后立即调用(用于提前释放串行锁) + // 流式请求在收到 2xx 响应头后调用,避免持锁等流完成 + OnUpstreamAccepted func() } // ParseGatewayRequest 解析网关请求体并返回结构化结果。 // protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini), // 不同协议使用不同的 system/messages 字段名。 func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) { - var req map[string]any - if err := json.Unmarshal(body, &req); err != nil { - return nil, err + // 保持与旧实现一致:请求体必须是合法 JSON。 + // 注意:gjson.GetBytes 对非法 JSON 不会报错,因此需要显式校验。 + if !gjson.ValidBytes(body) { + return nil, fmt.Errorf("invalid json") } + // 性能: + // - gjson.GetBytes 会把匹配的 Raw/Str 安全复制成 string(对于巨大 messages 会产生额外拷贝)。 + // - 这里将 body 通过 unsafe 零拷贝视为 string,仅在本函数内使用,且 body 不会被修改。 + jsonStr := *(*string)(unsafe.Pointer(&body)) + parsed := &ParsedRequest{ Body: body, } - if rawModel, exists := req["model"]; exists { - model, ok := rawModel.(string) - if !ok { + // --- gjson 提取简单字段(避免完整 Unmarshal) --- + + // model: 需要严格类型校验,非 string 返回错误 + modelResult := gjson.Get(jsonStr, "model") + if modelResult.Exists() { + if modelResult.Type != gjson.String { return nil, fmt.Errorf("invalid model field type") } - parsed.Model = model + parsed.Model = modelResult.String() } - if rawStream, exists := 
req["stream"]; exists { - stream, ok := rawStream.(bool) - if !ok { + + // stream: 需要严格类型校验,非 bool 返回错误 + streamResult := gjson.Get(jsonStr, "stream") + if streamResult.Exists() { + if streamResult.Type != gjson.True && streamResult.Type != gjson.False { return nil, fmt.Errorf("invalid stream field type") } - parsed.Stream = stream + parsed.Stream = streamResult.Bool() } - if metadata, ok := req["metadata"].(map[string]any); ok { - if userID, ok := metadata["user_id"].(string); ok { - parsed.MetadataUserID = userID + + // metadata.user_id: 直接路径提取,不需要严格类型校验 + parsed.MetadataUserID = gjson.Get(jsonStr, "metadata.user_id").String() + + // thinking.type: enabled/adaptive 都视为开启 + thinkingType := gjson.Get(jsonStr, "thinking.type").String() + if thinkingType == "enabled" || thinkingType == "adaptive" { + parsed.ThinkingEnabled = true + } + + // max_tokens: 仅接受整数值 + maxTokensResult := gjson.Get(jsonStr, "max_tokens") + if maxTokensResult.Exists() && maxTokensResult.Type == gjson.Number { + f := maxTokensResult.Float() + if !math.IsNaN(f) && !math.IsInf(f, 0) && f == math.Trunc(f) && + f <= float64(math.MaxInt) && f >= float64(math.MinInt) { + parsed.MaxTokens = int(f) } } + // --- system/messages 提取 --- + // 避免把整个 body Unmarshal 到 map(会产生大量 map/接口分配)。 + // 使用 gjson 抽取目标字段的 Raw,再对该子树进行 Unmarshal。 + switch protocol { case domain.PlatformGemini: // Gemini 原生格式: systemInstruction.parts / contents - if sysInst, ok := req["systemInstruction"].(map[string]any); ok { - if parts, ok := sysInst["parts"].([]any); ok { - parsed.System = parts + if sysParts := gjson.Get(jsonStr, "systemInstruction.parts"); sysParts.Exists() && sysParts.IsArray() { + var parts []any + if err := json.Unmarshal(sliceRawFromBody(body, sysParts), &parts); err != nil { + return nil, err } + parsed.System = parts } - if contents, ok := req["contents"].([]any); ok { - parsed.Messages = contents + + if contents := gjson.Get(jsonStr, "contents"); contents.Exists() && contents.IsArray() { + var msgs []any + if 
err := json.Unmarshal(sliceRawFromBody(body, contents), &msgs); err != nil { + return nil, err + } + parsed.Messages = msgs } default: // Anthropic / OpenAI 格式: system / messages // system 字段只要存在就视为显式提供(即使为 null), // 以避免客户端传 null 时被默认 system 误注入。 - if system, ok := req["system"]; ok { + if sys := gjson.Get(jsonStr, "system"); sys.Exists() { parsed.HasSystem = true - parsed.System = system + switch sys.Type { + case gjson.Null: + parsed.System = nil + case gjson.String: + // 与 encoding/json 的 Unmarshal 行为一致:返回解码后的字符串。 + parsed.System = sys.String() + default: + var system any + if err := json.Unmarshal(sliceRawFromBody(body, sys), &system); err != nil { + return nil, err + } + parsed.System = system + } } - if messages, ok := req["messages"].([]any); ok { + + if msgs := gjson.Get(jsonStr, "messages"); msgs.Exists() && msgs.IsArray() { + var messages []any + if err := json.Unmarshal(sliceRawFromBody(body, msgs), &messages); err != nil { + return nil, err + } parsed.Messages = messages } } - // thinking: {type: "enabled" | "adaptive"} - if rawThinking, ok := req["thinking"].(map[string]any); ok { - if t, ok := rawThinking["type"].(string); ok && (t == "enabled" || t == "adaptive") { - parsed.ThinkingEnabled = true - } - } - - // max_tokens - if rawMaxTokens, exists := req["max_tokens"]; exists { - if maxTokens, ok := parseIntegralNumber(rawMaxTokens); ok { - parsed.MaxTokens = maxTokens - } - } - return parsed, nil } -// parseIntegralNumber 将 JSON 解码后的数字安全转换为 int。 -// 仅接受“整数值”的输入,小数/NaN/Inf/越界值都会返回 false。 -func parseIntegralNumber(raw any) (int, bool) { - switch v := raw.(type) { - case float64: - if math.IsNaN(v) || math.IsInf(v, 0) || v != math.Trunc(v) { - return 0, false +// sliceRawFromBody 返回 Result.Raw 对应的原始字节切片。 +// 优先使用 Result.Index 直接从 body 切片,避免对大字段(如 messages)产生额外拷贝。 +// 当 Index 不可用时,退化为复制(理论上极少发生)。 +func sliceRawFromBody(body []byte, r gjson.Result) []byte { + if r.Index > 0 { + end := r.Index + len(r.Raw) + if end <= len(body) { + return 
body[r.Index:end] } - if v > float64(math.MaxInt) || v < float64(math.MinInt) { - return 0, false - } - return int(v), true - case int: - return v, true - case int8: - return int(v), true - case int16: - return int(v), true - case int32: - return int(v), true - case int64: - if v > int64(math.MaxInt) || v < int64(math.MinInt) { - return 0, false - } - return int(v), true - case json.Number: - i64, err := v.Int64() - if err != nil { - return 0, false - } - if i64 > int64(math.MaxInt) || i64 < int64(math.MinInt) { - return 0, false - } - return int(i64), true - default: - return 0, false } + // fallback: 不影响正确性,但会产生一次拷贝 + return []byte(r.Raw) } // FilterThinkingBlocks removes thinking blocks from request body @@ -184,49 +221,63 @@ func FilterThinkingBlocks(body []byte) []byte { // - Remove `redacted_thinking` blocks (cannot be converted to text). // - Ensure no message ends up with empty content. func FilterThinkingBlocksForRetry(body []byte) []byte { - hasThinkingContent := bytes.Contains(body, []byte(`"type":"thinking"`)) || - bytes.Contains(body, []byte(`"type": "thinking"`)) || - bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) || - bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) || - bytes.Contains(body, []byte(`"thinking":`)) || - bytes.Contains(body, []byte(`"thinking" :`)) + hasThinkingContent := bytes.Contains(body, patternTypeThinking) || + bytes.Contains(body, patternTypeThinkingSpaced) || + bytes.Contains(body, patternTypeRedactedThinking) || + bytes.Contains(body, patternTypeRedactedSpaced) || + bytes.Contains(body, patternThinkingField) || + bytes.Contains(body, patternThinkingFieldSpaced) // Also check for empty content arrays that need fixing. // Note: This is a heuristic check; the actual empty content handling is done below. 
- hasEmptyContent := bytes.Contains(body, []byte(`"content":[]`)) || - bytes.Contains(body, []byte(`"content": []`)) || - bytes.Contains(body, []byte(`"content" : []`)) || - bytes.Contains(body, []byte(`"content" :[]`)) + hasEmptyContent := bytes.Contains(body, patternEmptyContent) || + bytes.Contains(body, patternEmptyContentSpaced) || + bytes.Contains(body, patternEmptyContentSp1) || + bytes.Contains(body, patternEmptyContentSp2) // Fast path: nothing to process if !hasThinkingContent && !hasEmptyContent { return body } - var req map[string]any - if err := json.Unmarshal(body, &req); err != nil { + // 尽量避免把整个 body Unmarshal 成 map(会产生大量 map/接口分配)。 + // 这里先用 gjson 把 messages 子树摘出来,后续只对 messages 做 Unmarshal/Marshal。 + jsonStr := *(*string)(unsafe.Pointer(&body)) + msgsRes := gjson.Get(jsonStr, "messages") + if !msgsRes.Exists() || !msgsRes.IsArray() { + return body + } + + // Fast path:只需要删除顶层 thinking,不需要改 messages。 + // 注意:patternThinkingField 可能来自嵌套字段(如 tool_use.input.thinking),因此必须用 gjson 判断顶层字段是否存在。 + containsThinkingBlocks := bytes.Contains(body, patternTypeThinking) || + bytes.Contains(body, patternTypeThinkingSpaced) || + bytes.Contains(body, patternTypeRedactedThinking) || + bytes.Contains(body, patternTypeRedactedSpaced) || + bytes.Contains(body, patternThinkingFieldSpaced) + if !hasEmptyContent && !containsThinkingBlocks { + if topThinking := gjson.Get(jsonStr, "thinking"); topThinking.Exists() { + if out, err := sjson.DeleteBytes(body, "thinking"); err == nil { + return out + } + return body + } + return body + } + + var messages []any + if err := json.Unmarshal(sliceRawFromBody(body, msgsRes), &messages); err != nil { return body } modified := false - messages, ok := req["messages"].([]any) - if !ok { - return body - } - // Disable top-level thinking mode for retry to avoid structural/signature constraints upstream. 
- if _, exists := req["thinking"]; exists { - delete(req, "thinking") - modified = true - } + deleteTopLevelThinking := gjson.Get(jsonStr, "thinking").Exists() - newMessages := make([]any, 0, len(messages)) - - for _, msg := range messages { - msgMap, ok := msg.(map[string]any) + for i := 0; i < len(messages); i++ { + msgMap, ok := messages[i].(map[string]any) if !ok { - newMessages = append(newMessages, msg) continue } @@ -234,17 +285,30 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { content, ok := msgMap["content"].([]any) if !ok { // String content or other format - keep as is - newMessages = append(newMessages, msg) continue } - newContent := make([]any, 0, len(content)) + // 延迟分配:只有检测到需要修改的块,才构建新 slice。 + var newContent []any modifiedThisMsg := false - for _, block := range content { + ensureNewContent := func(prefixLen int) { + if newContent != nil { + return + } + newContent = make([]any, 0, len(content)) + if prefixLen > 0 { + newContent = append(newContent, content[:prefixLen]...) 
+ } + } + + for bi := 0; bi < len(content); bi++ { + block := content[bi] blockMap, ok := block.(map[string]any) if !ok { - newContent = append(newContent, block) + if newContent != nil { + newContent = append(newContent, block) + } continue } @@ -254,17 +318,15 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { switch blockType { case "thinking": modifiedThisMsg = true + ensureNewContent(bi) thinkingText, _ := blockMap["thinking"].(string) - if thinkingText == "" { - continue + if thinkingText != "" { + newContent = append(newContent, map[string]any{"type": "text", "text": thinkingText}) } - newContent = append(newContent, map[string]any{ - "type": "text", - "text": thinkingText, - }) continue case "redacted_thinking": modifiedThisMsg = true + ensureNewContent(bi) continue } @@ -272,6 +334,7 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { if blockType == "" { if rawThinking, hasThinking := blockMap["thinking"]; hasThinking { modifiedThisMsg = true + ensureNewContent(bi) switch v := rawThinking.(type) { case string: if v != "" { @@ -286,40 +349,64 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { } } - newContent = append(newContent, block) + if newContent != nil { + newContent = append(newContent, block) + } } // Handle empty content: either from filtering or originally empty + if newContent == nil { + if len(content) == 0 { + modified = true + placeholder := "(content removed)" + if role == "assistant" { + placeholder = "(assistant content removed)" + } + msgMap["content"] = []any{map[string]any{"type": "text", "text": placeholder}} + } + continue + } + if len(newContent) == 0 { modified = true placeholder := "(content removed)" if role == "assistant" { placeholder = "(assistant content removed)" } - newContent = append(newContent, map[string]any{ - "type": "text", - "text": placeholder, - }) - msgMap["content"] = newContent - } else if modifiedThisMsg { + msgMap["content"] = []any{map[string]any{"type": "text", "text": placeholder}} + 
continue + } + + if modifiedThisMsg { modified = true msgMap["content"] = newContent } - newMessages = append(newMessages, msgMap) } - if modified { - req["messages"] = newMessages - } else { + if !modified && !deleteTopLevelThinking { // Avoid rewriting JSON when no changes are needed. return body } - newBody, err := json.Marshal(req) - if err != nil { - return body + out := body + if deleteTopLevelThinking { + if b, err := sjson.DeleteBytes(out, "thinking"); err == nil { + out = b + } else { + return body + } } - return newBody + if modified { + msgsBytes, err := json.Marshal(messages) + if err != nil { + return body + } + out, err = sjson.SetRawBytes(out, "messages", msgsBytes) + if err != nil { + return body + } + } + return out } // FilterSignatureSensitiveBlocksForRetry is a stronger retry filter for cases where upstream errors indicate diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go index 5b85e752..2a9b4017 100644 --- a/backend/internal/service/gateway_request_test.go +++ b/backend/internal/service/gateway_request_test.go @@ -1,7 +1,11 @@ +//go:build unit + package service import ( "encoding/json" + "fmt" + "strings" "testing" "github.com/Wei-Shaw/sub2api/internal/domain" @@ -434,3 +438,341 @@ func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) { require.Contains(t, content0["text"], "tool_use") require.Contains(t, content1["text"], "tool_result") } + +// ============ Group 7: ParseGatewayRequest 补充单元测试 ============ + +// Task 7.1 — 类型校验边界测试 +func TestParseGatewayRequest_TypeValidation(t *testing.T) { + tests := []struct { + name string + body string + wantErr bool + errSubstr string // 期望的错误信息子串(为空则不检查) + }{ + { + name: "model 为 int", + body: `{"model":123}`, + wantErr: true, + errSubstr: "invalid model field type", + }, + { + name: "model 为 array", + body: `{"model":[]}`, + wantErr: true, + errSubstr: "invalid model field type", + }, + { + name: "model 为 bool", + 
body: `{"model":true}`, + wantErr: true, + errSubstr: "invalid model field type", + }, + { + name: "model 为 null — gjson Null 类型触发类型校验错误", + body: `{"model":null}`, + wantErr: true, // gjson: Exists()=true, Type=Null != String → 返回错误 + errSubstr: "invalid model field type", + }, + { + name: "stream 为 string", + body: `{"stream":"true"}`, + wantErr: true, + errSubstr: "invalid stream field type", + }, + { + name: "stream 为 int", + body: `{"stream":1}`, + wantErr: true, + errSubstr: "invalid stream field type", + }, + { + name: "stream 为 null — gjson Null 类型触发类型校验错误", + body: `{"stream":null}`, + wantErr: true, // gjson: Exists()=true, Type=Null != True && != False → 返回错误 + errSubstr: "invalid stream field type", + }, + { + name: "model 为 object", + body: `{"model":{}}`, + wantErr: true, + errSubstr: "invalid model field type", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ParseGatewayRequest([]byte(tt.body), "") + if tt.wantErr { + require.Error(t, err) + if tt.errSubstr != "" { + require.Contains(t, err.Error(), tt.errSubstr) + } + } else { + require.NoError(t, err) + } + }) + } +} + +// Task 7.2 — 可选字段缺失测试 +func TestParseGatewayRequest_OptionalFieldsMissing(t *testing.T) { + tests := []struct { + name string + body string + wantModel string + wantStream bool + wantMetadataUID string + wantHasSystem bool + wantThinking bool + wantMaxTokens int + wantMessagesNil bool + wantMessagesLen int + }{ + { + name: "完全空 JSON — 所有字段零值", + body: `{}`, + wantModel: "", + wantStream: false, + wantMetadataUID: "", + wantHasSystem: false, + wantThinking: false, + wantMaxTokens: 0, + wantMessagesNil: true, + }, + { + name: "metadata 无 user_id", + body: `{"model":"test"}`, + wantModel: "test", + wantMetadataUID: "", + wantHasSystem: false, + wantThinking: false, + }, + { + name: "thinking 非 enabled(type=disabled)", + body: `{"model":"test","thinking":{"type":"disabled"}}`, + wantModel: "test", + wantThinking: false, + }, + { + name: 
"thinking 字段缺失", + body: `{"model":"test"}`, + wantModel: "test", + wantThinking: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parsed, err := ParseGatewayRequest([]byte(tt.body), "") + require.NoError(t, err) + + require.Equal(t, tt.wantModel, parsed.Model) + require.Equal(t, tt.wantStream, parsed.Stream) + require.Equal(t, tt.wantMetadataUID, parsed.MetadataUserID) + require.Equal(t, tt.wantHasSystem, parsed.HasSystem) + require.Equal(t, tt.wantThinking, parsed.ThinkingEnabled) + require.Equal(t, tt.wantMaxTokens, parsed.MaxTokens) + + if tt.wantMessagesNil { + require.Nil(t, parsed.Messages) + } + if tt.wantMessagesLen > 0 { + require.Len(t, parsed.Messages, tt.wantMessagesLen) + } + }) + } +} + +// Task 7.3 — Gemini 协议分支测试 +// 已有测试覆盖: +// - TestParseGatewayRequest_GeminiSystemInstruction: 正常 systemInstruction+contents +// - TestParseGatewayRequest_GeminiNoContents: 缺失 contents +// - TestParseGatewayRequest_GeminiContents: 正常 contents(无 systemInstruction) +// 因此跳过。 + +// Task 7.4 — max_tokens 边界测试 +func TestParseGatewayRequest_MaxTokensBoundary(t *testing.T) { + tests := []struct { + name string + body string + wantMaxTokens int + wantErr bool + }{ + { + name: "正常整数", + body: `{"max_tokens":1024}`, + wantMaxTokens: 1024, + }, + { + name: "浮点数(非整数)被忽略", + body: `{"max_tokens":10.5}`, + wantMaxTokens: 0, + }, + { + name: "负整数可以通过", + body: `{"max_tokens":-1}`, + wantMaxTokens: -1, + }, + { + name: "超大值不 panic", + body: `{"max_tokens":9999999999999999}`, + wantMaxTokens: 10000000000000000, // float64 精度导致 9999999999999999 → 1e16 + }, + { + name: "null 值被忽略", + body: `{"max_tokens":null}`, + wantMaxTokens: 0, // gjson Type=Null != Number → 条件不满足,跳过 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parsed, err := ParseGatewayRequest([]byte(tt.body), "") + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.wantMaxTokens, parsed.MaxTokens) + }) + } 
+} + +// ============ Task 7.5: Benchmark 测试 ============ + +// parseGatewayRequestOld 是基于完整 json.Unmarshal 的旧实现,用于 benchmark 对比基线。 +// 核心路径:先 Unmarshal 到 map[string]any,再逐字段提取。 +func parseGatewayRequestOld(body []byte, protocol string) (*ParsedRequest, error) { + parsed := &ParsedRequest{ + Body: body, + } + + var req map[string]any + if err := json.Unmarshal(body, &req); err != nil { + return nil, err + } + + // model + if raw, ok := req["model"]; ok { + s, ok := raw.(string) + if !ok { + return nil, fmt.Errorf("invalid model field type") + } + parsed.Model = s + } + + // stream + if raw, ok := req["stream"]; ok { + b, ok := raw.(bool) + if !ok { + return nil, fmt.Errorf("invalid stream field type") + } + parsed.Stream = b + } + + // metadata.user_id + if meta, ok := req["metadata"].(map[string]any); ok { + if uid, ok := meta["user_id"].(string); ok { + parsed.MetadataUserID = uid + } + } + + // thinking.type + if thinking, ok := req["thinking"].(map[string]any); ok { + if thinkType, ok := thinking["type"].(string); ok && thinkType == "enabled" { + parsed.ThinkingEnabled = true + } + } + + // max_tokens + if raw, ok := req["max_tokens"]; ok { + if n, ok := parseIntegralNumber(raw); ok { + parsed.MaxTokens = n + } + } + + // system / messages(按协议分支) + switch protocol { + case domain.PlatformGemini: + if sysInst, ok := req["systemInstruction"].(map[string]any); ok { + if parts, ok := sysInst["parts"].([]any); ok { + parsed.System = parts + } + } + if contents, ok := req["contents"].([]any); ok { + parsed.Messages = contents + } + default: + if system, ok := req["system"]; ok { + parsed.HasSystem = true + parsed.System = system + } + if messages, ok := req["messages"].([]any); ok { + parsed.Messages = messages + } + } + + return parsed, nil +} + +// buildSmallJSON 构建 ~500B 的小型测试 JSON +func buildSmallJSON() []byte { + return 
[]byte(`{"model":"claude-sonnet-4-5","stream":true,"max_tokens":4096,"metadata":{"user_id":"user-abc123"},"thinking":{"type":"enabled","budget_tokens":2048},"system":"You are a helpful assistant.","messages":[{"role":"user","content":"What is the meaning of life?"},{"role":"assistant","content":"The meaning of life is a philosophical question."},{"role":"user","content":"Can you elaborate?"}]}`) +} + +// buildLargeJSON 构建 ~50KB 的大型测试 JSON(大量 messages) +func buildLargeJSON() []byte { + var b strings.Builder + b.WriteString(`{"model":"claude-sonnet-4-5","stream":true,"max_tokens":8192,"metadata":{"user_id":"user-xyz789"},"system":[{"type":"text","text":"You are a detailed assistant.","cache_control":{"type":"ephemeral"}}],"messages":[`) + + msgCount := 200 + for i := 0; i < msgCount; i++ { + if i > 0 { + b.WriteByte(',') + } + if i%2 == 0 { + b.WriteString(fmt.Sprintf(`{"role":"user","content":"This is user message number %d with some extra padding text to make the message reasonably long for benchmarking purposes. Lorem ipsum dolor sit amet."}`, i)) + } else { + b.WriteString(fmt.Sprintf(`{"role":"assistant","content":[{"type":"text","text":"This is assistant response number %d. 
I will provide a detailed answer with multiple sentences to simulate real conversation content for benchmark testing."}]}`, i)) + } + } + + b.WriteString(`]}`) + return []byte(b.String()) +} + +func BenchmarkParseGatewayRequest_Old_Small(b *testing.B) { + data := buildSmallJSON() + b.SetBytes(int64(len(data))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = parseGatewayRequestOld(data, "") + } +} + +func BenchmarkParseGatewayRequest_New_Small(b *testing.B) { + data := buildSmallJSON() + b.SetBytes(int64(len(data))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ParseGatewayRequest(data, "") + } +} + +func BenchmarkParseGatewayRequest_Old_Large(b *testing.B) { + data := buildLargeJSON() + b.SetBytes(int64(len(data))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = parseGatewayRequestOld(data, "") + } +} + +func BenchmarkParseGatewayRequest_New_Large(b *testing.B) { + data := buildLargeJSON() + b.SetBytes(int64(len(data))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ParseGatewayRequest(data, "") + } +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 56af4610..02f9a6a3 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "io" - "log" "log/slog" mathrand "math/rand" "net/http" @@ -24,12 +23,16 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/cespare/xxhash/v2" "github.com/google/uuid" + gocache "github.com/patrickmn/go-cache" "github.com/tidwall/gjson" "github.com/tidwall/sjson" + "golang.org/x/sync/singleflight" "github.com/gin-gonic/gin" ) @@ 
-44,6 +47,9 @@ const ( // separator between system blocks, we add "\n\n" at concatenation time. claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude." maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量 + + defaultUserGroupRateCacheTTL = 30 * time.Second + defaultModelsListCacheTTL = 15 * time.Second ) const ( @@ -62,6 +68,53 @@ type accountWithLoad struct { var ForceCacheBillingContextKey = forceCacheBillingKeyType{} +var ( + windowCostPrefetchCacheHitTotal atomic.Int64 + windowCostPrefetchCacheMissTotal atomic.Int64 + windowCostPrefetchBatchSQLTotal atomic.Int64 + windowCostPrefetchFallbackTotal atomic.Int64 + windowCostPrefetchErrorTotal atomic.Int64 + + userGroupRateCacheHitTotal atomic.Int64 + userGroupRateCacheMissTotal atomic.Int64 + userGroupRateCacheLoadTotal atomic.Int64 + userGroupRateCacheSFSharedTotal atomic.Int64 + userGroupRateCacheFallbackTotal atomic.Int64 + + modelsListCacheHitTotal atomic.Int64 + modelsListCacheMissTotal atomic.Int64 + modelsListCacheStoreTotal atomic.Int64 +) + +func GatewayWindowCostPrefetchStats() (cacheHit, cacheMiss, batchSQL, fallback, errCount int64) { + return windowCostPrefetchCacheHitTotal.Load(), + windowCostPrefetchCacheMissTotal.Load(), + windowCostPrefetchBatchSQLTotal.Load(), + windowCostPrefetchFallbackTotal.Load(), + windowCostPrefetchErrorTotal.Load() +} + +func GatewayUserGroupRateCacheStats() (cacheHit, cacheMiss, load, singleflightShared, fallback int64) { + return userGroupRateCacheHitTotal.Load(), + userGroupRateCacheMissTotal.Load(), + userGroupRateCacheLoadTotal.Load(), + userGroupRateCacheSFSharedTotal.Load(), + userGroupRateCacheFallbackTotal.Load() +} + +func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) { + return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load() +} + +func cloneStringSlice(src []string) []string { + if len(src) == 0 { + return nil + } + dst := make([]string, len(src)) + 
copy(dst, src) + return dst +} + // IsForceCacheBilling 检查是否启用强制缓存计费 func IsForceCacheBilling(ctx context.Context) bool { v, _ := ctx.Value(ForceCacheBillingContextKey).(bool) @@ -74,13 +127,26 @@ func WithForceCacheBilling(ctx context.Context) context.Context { } func (s *GatewayService) debugModelRoutingEnabled() bool { - v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING"))) - return v == "1" || v == "true" || v == "yes" || v == "on" + if s == nil { + return false + } + return s.debugModelRouting.Load() } func (s *GatewayService) debugClaudeMimicEnabled() bool { - v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC"))) - return v == "1" || v == "true" || v == "yes" || v == "on" + if s == nil { + return false + } + return s.debugClaudeMimic.Load() +} + +func parseDebugEnvBool(raw string) bool { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "1", "true", "yes", "on": + return true + default: + return false + } } func shortSessionHash(sessionHash string) string { @@ -213,7 +279,7 @@ func logClaudeMimicDebug(req *http.Request, body []byte, account *Account, token if line == "" { return } - log.Printf("[ClaudeMimicDebug] %s", line) + logger.LegacyPrintf("service.gateway", "[ClaudeMimicDebug] %s", line) } func isClaudeCodeCredentialScopeError(msg string) bool { @@ -302,6 +368,39 @@ func derefGroupID(groupID *int64) int64 { return *groupID } +func resolveUserGroupRateCacheTTL(cfg *config.Config) time.Duration { + if cfg == nil || cfg.Gateway.UserGroupRateCacheTTLSeconds <= 0 { + return defaultUserGroupRateCacheTTL + } + return time.Duration(cfg.Gateway.UserGroupRateCacheTTLSeconds) * time.Second +} + +func resolveModelsListCacheTTL(cfg *config.Config) time.Duration { + if cfg == nil || cfg.Gateway.ModelsListCacheTTLSeconds <= 0 { + return defaultModelsListCacheTTL + } + return time.Duration(cfg.Gateway.ModelsListCacheTTLSeconds) * time.Second +} + +func modelsListCacheKey(groupID *int64, platform string) 
string { + return fmt.Sprintf("%d|%s", derefGroupID(groupID), strings.TrimSpace(platform)) +} + +func prefetchedStickyGroupIDFromContext(ctx context.Context) (int64, bool) { + return PrefetchedStickyGroupIDFromContext(ctx) +} + +func prefetchedStickyAccountIDFromContext(ctx context.Context, groupID *int64) int64 { + prefetchedGroupID, ok := prefetchedStickyGroupIDFromContext(ctx) + if !ok || prefetchedGroupID != derefGroupID(groupID) { + return 0 + } + if accountID, ok := PrefetchedStickyAccountIDFromContext(ctx); ok && accountID > 0 { + return accountID + } + return 0 +} + // shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。 // 当账号状态为错误、禁用、不可调度、处于临时不可调度期间, // 或请求的模型处于限流状态时,返回 true。 @@ -349,6 +448,8 @@ type ClaudeUsage struct { OutputTokens int `json:"output_tokens"` CacheCreationInputTokens int `json:"cache_creation_input_tokens"` CacheReadInputTokens int `json:"cache_read_input_tokens"` + CacheCreation5mTokens int // 5分钟缓存创建token(来自嵌套 cache_creation 对象) + CacheCreation1hTokens int // 1小时缓存创建token(来自嵌套 cache_creation 对象) } // ForwardResult 转发结果 @@ -361,17 +462,22 @@ type ForwardResult struct { FirstTokenMs *int // 首字时间(流式请求) ClientDisconnect bool // 客户端是否在流式传输过程中断开 - // 图片生成计费字段(仅 gemini-3-pro-image 使用) + // 图片生成计费字段(图片生成模型使用) ImageCount int // 生成的图片数量 ImageSize string // 图片尺寸 "1K", "2K", "4K" + + // Sora 媒体字段 + MediaType string // image / video / prompt + MediaURL string // 生成后的媒体地址(可选) } // UpstreamFailoverError indicates an upstream error that should trigger account failover. 
type UpstreamFailoverError struct { StatusCode int - ResponseBody []byte // 上游响应体,用于错误透传规则匹配 - ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true - RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换 + ResponseBody []byte // 上游响应体,用于错误透传规则匹配 + ResponseHeaders http.Header // 上游响应头,用于透传 cf-ray/cf-mitigated/content-type 等诊断信息 + ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true + RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换 } func (e *UpstreamFailoverError) Error() string { @@ -395,25 +501,33 @@ func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accou // GatewayService handles API gateway operations type GatewayService struct { - accountRepo AccountRepository - groupRepo GroupRepository - usageLogRepo UsageLogRepository - userRepo UserRepository - userSubRepo UserSubscriptionRepository - userGroupRateRepo UserGroupRateRepository - cache GatewayCache - digestStore *DigestSessionStore - cfg *config.Config - schedulerSnapshot *SchedulerSnapshotService - billingService *BillingService - rateLimitService *RateLimitService - billingCacheService *BillingCacheService - identityService *IdentityService - httpUpstream HTTPUpstream - deferredService *DeferredService - concurrencyService *ConcurrencyService - claudeTokenProvider *ClaudeTokenProvider - sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken) + accountRepo AccountRepository + groupRepo GroupRepository + usageLogRepo UsageLogRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + userGroupRateRepo UserGroupRateRepository + cache GatewayCache + digestStore *DigestSessionStore + cfg *config.Config + schedulerSnapshot *SchedulerSnapshotService + billingService *BillingService + rateLimitService *RateLimitService + billingCacheService *BillingCacheService + identityService *IdentityService + httpUpstream HTTPUpstream + deferredService *DeferredService + concurrencyService *ConcurrencyService + 
claudeTokenProvider *ClaudeTokenProvider + sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken) + rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken) + userGroupRateCache *gocache.Cache + userGroupRateSF singleflight.Group + modelsListCache *gocache.Cache + modelsListCacheTTL time.Duration + responseHeaderFilter *responseheaders.CompiledHeaderFilter + debugModelRouting atomic.Bool + debugClaudeMimic atomic.Bool } // NewGatewayService creates a new GatewayService @@ -436,29 +550,41 @@ func NewGatewayService( deferredService *DeferredService, claudeTokenProvider *ClaudeTokenProvider, sessionLimitCache SessionLimitCache, + rpmCache RPMCache, digestStore *DigestSessionStore, ) *GatewayService { - return &GatewayService{ - accountRepo: accountRepo, - groupRepo: groupRepo, - usageLogRepo: usageLogRepo, - userRepo: userRepo, - userSubRepo: userSubRepo, - userGroupRateRepo: userGroupRateRepo, - cache: cache, - digestStore: digestStore, - cfg: cfg, - schedulerSnapshot: schedulerSnapshot, - concurrencyService: concurrencyService, - billingService: billingService, - rateLimitService: rateLimitService, - billingCacheService: billingCacheService, - identityService: identityService, - httpUpstream: httpUpstream, - deferredService: deferredService, - claudeTokenProvider: claudeTokenProvider, - sessionLimitCache: sessionLimitCache, + userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) + modelsListTTL := resolveModelsListCacheTTL(cfg) + + svc := &GatewayService{ + accountRepo: accountRepo, + groupRepo: groupRepo, + usageLogRepo: usageLogRepo, + userRepo: userRepo, + userSubRepo: userSubRepo, + userGroupRateRepo: userGroupRateRepo, + cache: cache, + digestStore: digestStore, + cfg: cfg, + schedulerSnapshot: schedulerSnapshot, + concurrencyService: concurrencyService, + billingService: billingService, + rateLimitService: rateLimitService, + billingCacheService: billingCacheService, + identityService: identityService, + httpUpstream: httpUpstream, + 
deferredService: deferredService, + claudeTokenProvider: claudeTokenProvider, + sessionLimitCache: sessionLimitCache, + rpmCache: rpmCache, + userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute), + modelsListCache: gocache.New(modelsListTTL, time.Minute), + modelsListCacheTTL: modelsListTTL, + responseHeaderFilter: compileResponseHeaderFilter(cfg), } + svc.debugModelRouting.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING"))) + svc.debugClaudeMimic.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC"))) + return svc } // GenerateSessionHash 从预解析请求计算粘性会话 hash @@ -929,13 +1055,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro cfg := s.schedulingConfig() - var stickyAccountID int64 - if sessionHash != "" && s.cache != nil { - if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil { - stickyAccountID = accountID - } - } - // 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组) group, groupID, err := s.checkClaudeCodeRestriction(ctx, groupID) if err != nil { @@ -943,12 +1062,21 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } ctx = s.withGroupContext(ctx, group) + var stickyAccountID int64 + if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 { + stickyAccountID = prefetch + } else if sessionHash != "" && s.cache != nil { + if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil { + stickyAccountID = accountID + } + } + if s.debugModelRoutingEnabled() && requestedModel != "" { groupPlatform := "" if group != nil { groupPlatform = group.Platform } - log.Printf("[ModelRoutingDebug] select entry: group_id=%v group_platform=%s model=%s session=%s sticky_account=%d load_batch=%v concurrency=%v", + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] select entry: group_id=%v group_platform=%s model=%s session=%s sticky_account=%d load_batch=%v 
concurrency=%v", derefGroupID(groupID), groupPlatform, requestedModel, shortSessionHash(sessionHash), stickyAccountID, cfg.LoadBatchEnabled, s.concurrencyService != nil) } @@ -1018,7 +1146,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } preferOAuth := platform == PlatformGemini if s.debugModelRoutingEnabled() && platform == PlatformAnthropic && requestedModel != "" { - log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform) } accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) @@ -1028,6 +1156,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if len(accounts) == 0 { return nil, errors.New("no available accounts") } + ctx = s.withWindowCostPrefetch(ctx, accounts) + ctx = s.withRPMPrefetch(ctx, accounts) isExcluded := func(accountID int64) bool { if excludedIDs == nil { @@ -1048,7 +1178,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic { routingAccountIDs = group.GetRoutingAccountIDs(requestedModel) if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] context group routing: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v session=%s sticky_account=%d", + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] context group routing: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v session=%s sticky_account=%d", group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), routingAccountIDs, shortSessionHash(sessionHash), stickyAccountID) if len(routingAccountIDs) == 
0 && group.ModelRoutingEnabled && len(group.ModelRouting) > 0 { keys := make([]string, 0, len(group.ModelRouting)) @@ -1060,7 +1190,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if len(keys) > maxKeys { keys = keys[:maxKeys] } - log.Printf("[ModelRoutingDebug] context group routing miss: group_id=%d model=%s patterns(sample)=%v", group.ID, requestedModel, keys) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] context group routing miss: group_id=%d model=%s patterns(sample)=%v", group.ID, requestedModel, keys) } } } @@ -1077,7 +1207,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro continue } account, ok := accountByID[routingAccountID] - if !ok || !account.IsSchedulable() { + if !ok || !s.isAccountSchedulableForSelection(account) { if !ok { filteredMissing++ } else { @@ -1093,7 +1223,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro filteredModelMapping++ continue } - if !account.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) { filteredModelScope++ modelScopeSkippedIDs = append(modelScopeSkippedIDs, account.ID) continue @@ -1103,31 +1233,36 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro filteredWindowCost++ continue } + // RPM 检查(非粘性会话路径) + if !s.isAccountSchedulableForRPM(ctx, account, false) { + continue + } routingCandidates = append(routingCandidates, account) } if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)", + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d 
window_cost=%d)", derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates), filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost) if len(modelScopeSkippedIDs) > 0 { - log.Printf("[ModelRoutingDebug] model_rate_limited accounts skipped: group_id=%v model=%s account_ids=%v", + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] model_rate_limited accounts skipped: group_id=%v model=%s account_ids=%v", derefGroupID(groupID), requestedModel, modelScopeSkippedIDs) } } if len(routingCandidates) > 0 { // 1.5. 在路由账号范围内检查粘性会话 - if sessionHash != "" && s.cache != nil { - stickyAccountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) - if err == nil && stickyAccountID > 0 && containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { + if sessionHash != "" && stickyAccountID > 0 { + if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { // 粘性账号在路由列表中,优先使用 if stickyAccount, ok := accountByID[stickyAccountID]; ok { - if stickyAccount.IsSchedulable() && + if s.isAccountSchedulableForSelection(stickyAccount) && s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) && - stickyAccount.IsSchedulableForModelWithContext(ctx, requestedModel) && - s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查 + s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) && + s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) && + + s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查 result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 @@ -1136,7 +1271,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, 
gro // 继续到负载感知选择 } else { if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) } return &AccountSelectionResult{ Account: stickyAccount, @@ -1229,7 +1364,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL) } if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) } return &AccountSelectionResult{ Account: item.account, @@ -1246,7 +1381,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro continue // 会话限制已满,尝试下一个 } if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) } return &AccountSelectionResult{ Account: item.account, @@ -1261,14 +1396,14 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro // 所有路由账号会话限制都已满,继续到 Layer 2 回退 } // 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退 
- log.Printf("[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel) + logger.LegacyPrintf("service.gateway", "[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel) } } // ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============ - if len(routingAccountIDs) == 0 && sessionHash != "" && s.cache != nil { - accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) - if err == nil && accountID > 0 && !isExcluded(accountID) { + if len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) { + accountID := stickyAccountID + if accountID > 0 && !isExcluded(accountID) { account, ok := accountByID[accountID] if ok { // 检查账户是否需要清理粘性会话绑定 @@ -1280,8 +1415,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !clearSticky && s.isAccountInGroup(account, groupID) && s.isAccountAllowedForPlatform(account, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && - account.IsSchedulableForModelWithContext(ctx, requestedModel) && - s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查 + s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && + s.isAccountSchedulableForWindowCost(ctx, account, true) && + + s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查 result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 @@ -1331,7 +1468,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro // Scheduler snapshots can be temporarily stale (bucket rebuild is throttled); // re-check schedulability here so recently rate-limited/overloaded accounts // are not selected again before the bucket is rebuilt. 
- if !acc.IsSchedulable() { + if !s.isAccountSchedulableForSelection(acc) { continue } if !s.isAccountAllowedForPlatform(acc, platform, useMixed) { @@ -1340,13 +1477,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } // 窗口费用检查(非粘性会话路径) if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { continue } + // RPM 检查(非粘性会话路径) + if !s.isAccountSchedulableForRPM(ctx, acc, false) { + continue + } candidates = append(candidates, acc) } @@ -1522,20 +1663,20 @@ func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupI group, err := s.resolveGroupByID(ctx, *groupID) if err != nil || group == nil { if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] resolve group failed: group_id=%v model=%s platform=%s err=%v", derefGroupID(groupID), requestedModel, platform, err) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] resolve group failed: group_id=%v model=%s platform=%s err=%v", derefGroupID(groupID), requestedModel, platform, err) } return nil } // Preserve existing behavior: model routing only applies to anthropic groups. 
if group.Platform != PlatformAnthropic { if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] skip: non-anthropic group platform: group_id=%d group_platform=%s model=%s", group.ID, group.Platform, requestedModel) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] skip: non-anthropic group platform: group_id=%d group_platform=%s model=%s", group.ID, group.Platform, requestedModel) } return nil } ids := group.GetRoutingAccountIDs(requestedModel) if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] routing lookup: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v", + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routing lookup: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v", group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), ids) } return ids @@ -1611,6 +1752,9 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr } func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { + if platform == PlatformSora { + return s.listSoraSchedulableAccounts(ctx, groupID) + } if s.schedulerSnapshot != nil { accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) if err == nil { @@ -1638,8 +1782,10 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i var err error if groupID != nil { accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms) - } else { + } else if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) + } else { + accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, platforms) } if err != nil { slog.Debug("account_scheduling_list_failed", @@ -1680,7 +1826,7 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i 
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform) // 分组内无账号则返回空列表,由上层处理错误,不再回退到全平台查询 } else { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) + accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, platform) } if err != nil { slog.Debug("account_scheduling_list_failed", @@ -1705,6 +1851,53 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i return accounts, useMixed, nil } +func (s *GatewayService) listSoraSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, bool, error) { + const useMixed = false + + var accounts []Account + var err error + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora) + } else if groupID != nil { + accounts, err = s.accountRepo.ListByGroup(ctx, *groupID) + } else { + accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora) + } + if err != nil { + slog.Debug("account_scheduling_list_failed", + "group_id", derefGroupID(groupID), + "platform", PlatformSora, + "error", err) + return nil, useMixed, err + } + + filtered := make([]Account, 0, len(accounts)) + for _, acc := range accounts { + if acc.Platform != PlatformSora { + continue + } + if !s.isSoraAccountSchedulable(&acc) { + continue + } + filtered = append(filtered, acc) + } + slog.Debug("account_scheduling_list_sora", + "group_id", derefGroupID(groupID), + "platform", PlatformSora, + "raw_count", len(accounts), + "filtered_count", len(filtered)) + for _, acc := range filtered { + slog.Debug("account_scheduling_account_detail", + "account_id", acc.ID, + "name", acc.Name, + "platform", acc.Platform, + "type", acc.Type, + "status", acc.Status, + "tls_fingerprint", acc.IsTLSFingerprintEnabled()) + } + return filtered, useMixed, nil +} + // IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。 // 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context, // 避免单账号分组收到 503 
时错误地设置模型限流标记导致后续请求连续快速失败。 @@ -1729,15 +1922,59 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform return account.Platform == platform } -// isAccountInGroup checks if the account belongs to the specified group. -// Returns true if groupID is nil (no group restriction) or account belongs to the group. -func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool { - if groupID == nil { - return true // 无分组限制 +func (s *GatewayService) isSoraAccountSchedulable(account *Account) bool { + return s.soraUnschedulableReason(account) == "" +} + +func (s *GatewayService) soraUnschedulableReason(account *Account) string { + if account == nil { + return "account_nil" } + if account.Status != StatusActive { + return fmt.Sprintf("status=%s", account.Status) + } + if !account.Schedulable { + return "schedulable=false" + } + if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { + return fmt.Sprintf("temp_unschedulable_until=%s", account.TempUnschedulableUntil.UTC().Format(time.RFC3339)) + } + return "" +} + +func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool { if account == nil { return false } + if account.Platform == PlatformSora { + return s.isSoraAccountSchedulable(account) + } + return account.IsSchedulable() +} + +func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Context, account *Account, requestedModel string) bool { + if account == nil { + return false + } + if account.Platform == PlatformSora { + if !s.isSoraAccountSchedulable(account) { + return false + } + return account.GetRateLimitRemainingTimeWithContext(ctx, requestedModel) <= 0 + } + return account.IsSchedulableForModelWithContext(ctx, requestedModel) +} + +// isAccountInGroup checks if the account belongs to the specified group. +// When groupID is nil, returns true only for ungrouped accounts (no group assignments). 
+func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool { + if account == nil { + return false + } + if groupID == nil { + // 无分组的 API Key 只能使用未分组的账号 + return len(account.AccountGroups) == 0 + } for _, ag := range account.AccountGroups { if ag.GroupID == *groupID { return true @@ -1753,6 +1990,129 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) } +type usageLogWindowStatsBatchProvider interface { + GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) +} + +type windowCostPrefetchContextKeyType struct{} + +var windowCostPrefetchContextKey = windowCostPrefetchContextKeyType{} + +func windowCostFromPrefetchContext(ctx context.Context, accountID int64) (float64, bool) { + if ctx == nil || accountID <= 0 { + return 0, false + } + m, ok := ctx.Value(windowCostPrefetchContextKey).(map[int64]float64) + if !ok || len(m) == 0 { + return 0, false + } + v, exists := m[accountID] + return v, exists +} + +func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts []Account) context.Context { + if ctx == nil || len(accounts) == 0 || s.sessionLimitCache == nil || s.usageLogRepo == nil { + return ctx + } + + accountByID := make(map[int64]*Account) + accountIDs := make([]int64, 0, len(accounts)) + for i := range accounts { + account := &accounts[i] + if account == nil || !account.IsAnthropicOAuthOrSetupToken() { + continue + } + if account.GetWindowCostLimit() <= 0 { + continue + } + accountByID[account.ID] = account + accountIDs = append(accountIDs, account.ID) + } + if len(accountIDs) == 0 { + return ctx + } + + costs := make(map[int64]float64, len(accountIDs)) + cacheValues, err := s.sessionLimitCache.GetWindowCostBatch(ctx, accountIDs) + if err == nil { + for accountID, cost := range cacheValues { + costs[accountID] = cost + } + 
windowCostPrefetchCacheHitTotal.Add(int64(len(cacheValues))) + } else { + windowCostPrefetchErrorTotal.Add(1) + logger.LegacyPrintf("service.gateway", "window_cost batch cache read failed: %v", err) + } + cacheMissCount := len(accountIDs) - len(costs) + if cacheMissCount < 0 { + cacheMissCount = 0 + } + windowCostPrefetchCacheMissTotal.Add(int64(cacheMissCount)) + + missingByStart := make(map[int64][]int64) + startTimes := make(map[int64]time.Time) + for _, accountID := range accountIDs { + if _, ok := costs[accountID]; ok { + continue + } + account := accountByID[accountID] + if account == nil { + continue + } + startTime := account.GetCurrentWindowStartTime() + startKey := startTime.Unix() + missingByStart[startKey] = append(missingByStart[startKey], accountID) + startTimes[startKey] = startTime + } + if len(missingByStart) == 0 { + return context.WithValue(ctx, windowCostPrefetchContextKey, costs) + } + + batchReader, hasBatch := s.usageLogRepo.(usageLogWindowStatsBatchProvider) + for startKey, ids := range missingByStart { + startTime := startTimes[startKey] + + if hasBatch { + windowCostPrefetchBatchSQLTotal.Add(1) + queryStart := time.Now() + statsByAccount, err := batchReader.GetAccountWindowStatsBatch(ctx, ids, startTime) + if err == nil { + slog.Debug("window_cost_batch_query_ok", + "accounts", len(ids), + "window_start", startTime.Format(time.RFC3339), + "duration_ms", time.Since(queryStart).Milliseconds()) + for _, accountID := range ids { + stats := statsByAccount[accountID] + cost := 0.0 + if stats != nil { + cost = stats.StandardCost + } + costs[accountID] = cost + _ = s.sessionLimitCache.SetWindowCost(ctx, accountID, cost) + } + continue + } + windowCostPrefetchErrorTotal.Add(1) + logger.LegacyPrintf("service.gateway", "window_cost batch db query failed: start=%s err=%v", startTime.Format(time.RFC3339), err) + } + + // 回退路径:缺少批量仓储能力或批量查询失败时,按账号单查(失败开放)。 + windowCostPrefetchFallbackTotal.Add(int64(len(ids))) + for _, accountID := range ids { + stats, 
err := s.usageLogRepo.GetAccountWindowStats(ctx, accountID, startTime) + if err != nil { + windowCostPrefetchErrorTotal.Add(1) + continue + } + cost := stats.StandardCost + costs[accountID] = cost + _ = s.sessionLimitCache.SetWindowCost(ctx, accountID, cost) + } + } + + return context.WithValue(ctx, windowCostPrefetchContextKey, costs) +} + // isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度 // 仅适用于 Anthropic OAuth/SetupToken 账号 // 返回 true 表示可调度,false 表示不可调度 @@ -1769,6 +2129,10 @@ func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context, // 尝试从缓存获取窗口费用 var currentCost float64 + if cost, ok := windowCostFromPrefetchContext(ctx, account.ID); ok { + currentCost = cost + goto checkSchedulability + } if s.sessionLimitCache != nil { if cost, hit, err := s.sessionLimitCache.GetWindowCost(ctx, account.ID); err == nil && hit { currentCost = cost @@ -1810,6 +2174,88 @@ checkSchedulability: return true } +// rpmPrefetchContextKey is the context key for prefetched RPM counts. 
+type rpmPrefetchContextKeyType struct{} + +var rpmPrefetchContextKey = rpmPrefetchContextKeyType{} + +func rpmFromPrefetchContext(ctx context.Context, accountID int64) (int, bool) { + if v, ok := ctx.Value(rpmPrefetchContextKey).(map[int64]int); ok { + count, found := v[accountID] + return count, found + } + return 0, false +} + +// withRPMPrefetch 批量预取所有候选账号的 RPM 计数 +func (s *GatewayService) withRPMPrefetch(ctx context.Context, accounts []Account) context.Context { + if s.rpmCache == nil { + return ctx + } + + var ids []int64 + for i := range accounts { + if accounts[i].IsAnthropicOAuthOrSetupToken() && accounts[i].GetBaseRPM() > 0 { + ids = append(ids, accounts[i].ID) + } + } + if len(ids) == 0 { + return ctx + } + + counts, err := s.rpmCache.GetRPMBatch(ctx, ids) + if err != nil { + return ctx // 失败开放 + } + return context.WithValue(ctx, rpmPrefetchContextKey, counts) +} + +// isAccountSchedulableForRPM 检查账号是否可根据 RPM 进行调度 +// 仅适用于 Anthropic OAuth/SetupToken 账号 +func (s *GatewayService) isAccountSchedulableForRPM(ctx context.Context, account *Account, isSticky bool) bool { + if !account.IsAnthropicOAuthOrSetupToken() { + return true + } + baseRPM := account.GetBaseRPM() + if baseRPM <= 0 { + return true + } + + // 尝试从预取缓存获取 + var currentRPM int + if count, ok := rpmFromPrefetchContext(ctx, account.ID); ok { + currentRPM = count + } else if s.rpmCache != nil { + if count, err := s.rpmCache.GetRPM(ctx, account.ID); err == nil { + currentRPM = count + } + // 失败开放:GetRPM 错误时允许调度 + } + + schedulability := account.CheckRPMSchedulability(currentRPM) + switch schedulability { + case WindowCostSchedulable: + return true + case WindowCostStickyOnly: + return isSticky + case WindowCostNotSchedulable: + return false + } + return true +} + +// IncrementAccountRPM increments the RPM counter for the given account. 
+// 已知 TOCTOU 竞态:调度时读取 RPM 计数与此处递增之间存在时间窗口, +// 高并发下可能短暂超出 RPM 限制。这是与 WindowCost 一致的 soft-limit +// 设计权衡——可接受的少量超额优于加锁带来的延迟和复杂度。 +func (s *GatewayService) IncrementAccountRPM(ctx context.Context, accountID int64) error { + if s.rpmCache == nil { + return nil + } + _, err := s.rpmCache.IncrementRPM(ctx, accountID) + return err +} + // checkAndRegisterSession 检查并注册会话,用于会话数量限制 // 仅适用于 Anthropic OAuth/SetupToken 账号 // sessionID: 会话标识符(使用粘性会话的 hash) @@ -1966,7 +2412,7 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { return a.LastUsedAt.Before(*b.LastUsedAt) } }) - shuffleWithinPriorityAndLastUsed(accounts) + shuffleWithinPriorityAndLastUsed(accounts, preferOAuth) } // shuffleWithinSortGroups 对排序后的 accountWithLoad 切片,按 (Priority, LoadRate, LastUsedAt) 分组后组内随机打乱。 @@ -2002,7 +2448,12 @@ func sameAccountWithLoadGroup(a, b accountWithLoad) bool { } // shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。 -func shuffleWithinPriorityAndLastUsed(accounts []*Account) { +// +// 注意:当 preferOAuth=true 时,需要保证 OAuth 账号在同组内仍然优先,否则会把排序时的偏好打散掉。 +// 因此这里采用"组内分区 + 分区内 shuffle"的方式: +// - 先把同组账号按 (OAuth / 非 OAuth) 拆成两段,保持 OAuth 段在前; +// - 再分别在各段内随机打散,避免热点。 +func shuffleWithinPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { if len(accounts) <= 1 { return } @@ -2013,9 +2464,29 @@ func shuffleWithinPriorityAndLastUsed(accounts []*Account) { j++ } if j-i > 1 { - mathrand.Shuffle(j-i, func(a, b int) { - accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a] - }) + if preferOAuth { + oauth := make([]*Account, 0, j-i) + others := make([]*Account, 0, j-i) + for _, acc := range accounts[i:j] { + if acc.Type == AccountTypeOAuth { + oauth = append(oauth, acc) + } else { + others = append(others, acc) + } + } + if len(oauth) > 1 { + mathrand.Shuffle(len(oauth), func(a, b int) { oauth[a], oauth[b] = oauth[b], oauth[a] }) + } + if len(others) > 1 { + mathrand.Shuffle(len(others), func(a, b int) { others[a], others[b] = 
others[b], others[a] }) + } + copy(accounts[i:], oauth) + copy(accounts[i+len(oauth):], others) + } else { + mathrand.Shuffle(j-i, func(a, b int) { + accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a] + }) + } } i = j } @@ -2104,7 +2575,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, // so switching model can switch upstream account within the same sticky session. if len(routingAccountIDs) > 0 { if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] legacy routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v", + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v", derefGroupID(groupID), requestedModel, platform, shortSessionHash(sessionHash), routingAccountIDs) } // 1) Sticky session only applies if the bound account is within the routing set. @@ -2119,9 +2590,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) + 
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) } return account, nil } @@ -2142,6 +2613,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } accountsLoaded = true + // 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存 + ctx = s.withWindowCostPrefetch(ctx, accounts) + ctx = s.withRPMPrefetch(ctx, accounts) + routingSet := make(map[int64]struct{}, len(routingAccountIDs)) for _, id := range routingAccountIDs { if id > 0 { @@ -2160,13 +2635,19 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } // Scheduler snapshots can be temporarily stale; re-check schedulability here to // avoid selecting accounts that were recently rate-limited/overloaded. - if !acc.IsSchedulable() { + if !s.isAccountSchedulableForSelection(acc) { continue } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { + continue + } + if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { + continue + } + if !s.isAccountSchedulableForRPM(ctx, acc, false) { continue } if selected == nil { @@ -2196,15 +2677,15 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if selected != nil { if sessionHash != "" && s.cache != nil { if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { - log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) + logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) } } if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] legacy 
routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID) } return selected, nil } - log.Printf("[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel) + logger.LegacyPrintf("service.gateway", "[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel) } // 1. 查询粘性会话 @@ -2219,7 +2700,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { return account, nil } } @@ -2240,6 +2721,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } } + // 批量预取窗口费用+RPM 计数,避免逐个账号查询(N+1) + ctx = s.withWindowCostPrefetch(ctx, accounts) + ctx = s.withRPMPrefetch(ctx, accounts) + // 3. 
按优先级+最久未用选择(考虑模型支持) var selected *Account for i := range accounts { @@ -2249,13 +2734,19 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } // Scheduler snapshots can be temporarily stale; re-check schedulability here to // avoid selecting accounts that were recently rate-limited/overloaded. - if !acc.IsSchedulable() { + if !s.isAccountSchedulableForSelection(acc) { continue } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { + continue + } + if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { + continue + } + if !s.isAccountSchedulableForRPM(ctx, acc, false) { continue } if selected == nil { @@ -2283,8 +2774,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } if selected == nil { + stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, platform, accounts, excludedIDs, false) if requestedModel != "" { - return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel) + return nil, fmt.Errorf("no available accounts supporting model: %s (%s)", requestedModel, summarizeSelectionFailureStats(stats)) } return nil, errors.New("no available accounts") } @@ -2292,7 +2784,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, // 4. 
建立粘性绑定 if sessionHash != "" && s.cache != nil { if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { - log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) + logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) } } @@ -2311,7 +2803,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g // ============ Model Routing (legacy path): apply before sticky session ============ if len(routingAccountIDs) > 0 { if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] legacy mixed routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v", + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v", derefGroupID(groupID), requestedModel, nativePlatform, shortSessionHash(sessionHash), routingAccountIDs) } // 1) Sticky session only applies if the bound account is within the routing set. 
@@ -2326,10 +2818,10 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) } return account, nil } @@ -2347,6 +2839,10 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } accountsLoaded = true + // 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存 + ctx = s.withWindowCostPrefetch(ctx, accounts) + ctx = s.withRPMPrefetch(ctx, accounts) + routingSet := make(map[int64]struct{}, len(routingAccountIDs)) for _, id := range routingAccountIDs { if id > 0 { @@ -2365,7 +2861,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } // Scheduler snapshots can be temporarily stale; re-check schedulability here to // avoid selecting accounts that were recently rate-limited/overloaded. 
- if !acc.IsSchedulable() { + if !s.isAccountSchedulableForSelection(acc) { continue } // 过滤:原生平台直接通过,antigravity 需要启用混合调度 @@ -2375,7 +2871,13 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { + continue + } + if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { + continue + } + if !s.isAccountSchedulableForRPM(ctx, acc, false) { continue } if selected == nil { @@ -2405,15 +2907,15 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if selected != nil { if sessionHash != "" && s.cache != nil { if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { - log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) + logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) } } if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] legacy mixed routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID) } return selected, nil } - log.Printf("[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel) + logger.LegacyPrintf("service.gateway", "[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel) } // 1. 
查询粘性会话 @@ -2428,7 +2930,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { return account, nil } @@ -2447,6 +2949,10 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } } + // 批量预取窗口费用+RPM 计数,避免逐个账号查询(N+1) + ctx = s.withWindowCostPrefetch(ctx, accounts) + ctx = s.withRPMPrefetch(ctx, accounts) + // 3. 按优先级+最久未用选择(考虑模型支持和混合调度) var selected *Account for i := range accounts { @@ -2456,7 +2962,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } // Scheduler snapshots can be temporarily stale; re-check schedulability here to // avoid selecting accounts that were recently rate-limited/overloaded. 
- if !acc.IsSchedulable() { + if !s.isAccountSchedulableForSelection(acc) { continue } // 过滤:原生平台直接通过,antigravity 需要启用混合调度 @@ -2466,7 +2972,13 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { + continue + } + if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { + continue + } + if !s.isAccountSchedulableForRPM(ctx, acc, false) { continue } if selected == nil { @@ -2494,8 +3006,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } if selected == nil { + stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, nativePlatform, accounts, excludedIDs, true) if requestedModel != "" { - return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel) + return nil, fmt.Errorf("no available accounts supporting model: %s (%s)", requestedModel, summarizeSelectionFailureStats(stats)) } return nil, errors.New("no available accounts") } @@ -2503,13 +3016,243 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g // 4. 
建立粘性绑定 if sessionHash != "" && s.cache != nil { if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { - log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) + logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) } } return selected, nil } +type selectionFailureStats struct { + Total int + Eligible int + Excluded int + Unschedulable int + PlatformFiltered int + ModelUnsupported int + ModelRateLimited int + SamplePlatformIDs []int64 + SampleMappingIDs []int64 + SampleRateLimitIDs []string +} + +type selectionFailureDiagnosis struct { + Category string + Detail string +} + +func (s *GatewayService) logDetailedSelectionFailure( + ctx context.Context, + groupID *int64, + sessionHash string, + requestedModel string, + platform string, + accounts []Account, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) selectionFailureStats { + stats := s.collectSelectionFailureStats(ctx, accounts, requestedModel, platform, excludedIDs, allowMixedScheduling) + logger.LegacyPrintf( + "service.gateway", + "[SelectAccountDetailed] group_id=%v model=%s platform=%s session=%s total=%d eligible=%d excluded=%d unschedulable=%d platform_filtered=%d model_unsupported=%d model_rate_limited=%d sample_platform_filtered=%v sample_model_unsupported=%v sample_model_rate_limited=%v", + derefGroupID(groupID), + requestedModel, + platform, + shortSessionHash(sessionHash), + stats.Total, + stats.Eligible, + stats.Excluded, + stats.Unschedulable, + stats.PlatformFiltered, + stats.ModelUnsupported, + stats.ModelRateLimited, + stats.SamplePlatformIDs, + stats.SampleMappingIDs, + stats.SampleRateLimitIDs, + ) + if platform == PlatformSora { + s.logSoraSelectionFailureDetails(ctx, groupID, sessionHash, requestedModel, accounts, excludedIDs, allowMixedScheduling) + } + return stats +} + +func (s 
*GatewayService) collectSelectionFailureStats( + ctx context.Context, + accounts []Account, + requestedModel string, + platform string, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) selectionFailureStats { + stats := selectionFailureStats{ + Total: len(accounts), + } + + for i := range accounts { + acc := &accounts[i] + diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, platform, excludedIDs, allowMixedScheduling) + switch diagnosis.Category { + case "excluded": + stats.Excluded++ + case "unschedulable": + stats.Unschedulable++ + case "platform_filtered": + stats.PlatformFiltered++ + stats.SamplePlatformIDs = appendSelectionFailureSampleID(stats.SamplePlatformIDs, acc.ID) + case "model_unsupported": + stats.ModelUnsupported++ + stats.SampleMappingIDs = appendSelectionFailureSampleID(stats.SampleMappingIDs, acc.ID) + case "model_rate_limited": + stats.ModelRateLimited++ + remaining := acc.GetRateLimitRemainingTimeWithContext(ctx, requestedModel).Truncate(time.Second) + stats.SampleRateLimitIDs = appendSelectionFailureRateSample(stats.SampleRateLimitIDs, acc.ID, remaining) + default: + stats.Eligible++ + } + } + + return stats +} + +func (s *GatewayService) diagnoseSelectionFailure( + ctx context.Context, + acc *Account, + requestedModel string, + platform string, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) selectionFailureDiagnosis { + if acc == nil { + return selectionFailureDiagnosis{Category: "unschedulable", Detail: "account_nil"} + } + if _, excluded := excludedIDs[acc.ID]; excluded { + return selectionFailureDiagnosis{Category: "excluded"} + } + if !s.isAccountSchedulableForSelection(acc) { + detail := "generic_unschedulable" + if acc.Platform == PlatformSora { + detail = s.soraUnschedulableReason(acc) + } + return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail} + } + if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) { + return selectionFailureDiagnosis{ + 
Category: "platform_filtered", + Detail: fmt.Sprintf("account_platform=%s requested_platform=%s", acc.Platform, strings.TrimSpace(platform)), + } + } + if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { + return selectionFailureDiagnosis{ + Category: "model_unsupported", + Detail: fmt.Sprintf("model=%s", requestedModel), + } + } + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { + remaining := acc.GetRateLimitRemainingTimeWithContext(ctx, requestedModel).Truncate(time.Second) + return selectionFailureDiagnosis{ + Category: "model_rate_limited", + Detail: fmt.Sprintf("remaining=%s", remaining), + } + } + return selectionFailureDiagnosis{Category: "eligible"} +} + +func (s *GatewayService) logSoraSelectionFailureDetails( + ctx context.Context, + groupID *int64, + sessionHash string, + requestedModel string, + accounts []Account, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) { + const maxLines = 30 + logged := 0 + + for i := range accounts { + if logged >= maxLines { + break + } + acc := &accounts[i] + diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, PlatformSora, excludedIDs, allowMixedScheduling) + if diagnosis.Category == "eligible" { + continue + } + detail := diagnosis.Detail + if detail == "" { + detail = "-" + } + logger.LegacyPrintf( + "service.gateway", + "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s account_id=%d account_platform=%s category=%s detail=%s", + derefGroupID(groupID), + requestedModel, + shortSessionHash(sessionHash), + acc.ID, + acc.Platform, + diagnosis.Category, + detail, + ) + logged++ + } + if len(accounts) > maxLines { + logger.LegacyPrintf( + "service.gateway", + "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s truncated=true total=%d logged=%d", + derefGroupID(groupID), + requestedModel, + shortSessionHash(sessionHash), + len(accounts), + logged, + ) + } +} + +func isPlatformFilteredForSelection(acc 
*Account, platform string, allowMixedScheduling bool) bool { + if acc == nil { + return true + } + if allowMixedScheduling { + if acc.Platform == PlatformAntigravity { + return !acc.IsMixedSchedulingEnabled() + } + return acc.Platform != platform + } + if strings.TrimSpace(platform) == "" { + return false + } + return acc.Platform != platform +} + +func appendSelectionFailureSampleID(samples []int64, id int64) []int64 { + const limit = 5 + if len(samples) >= limit { + return samples + } + return append(samples, id) +} + +func appendSelectionFailureRateSample(samples []string, accountID int64, remaining time.Duration) []string { + const limit = 5 + if len(samples) >= limit { + return samples + } + return append(samples, fmt.Sprintf("%d(%s)", accountID, remaining)) +} + +func summarizeSelectionFailureStats(stats selectionFailureStats) string { + return fmt.Sprintf( + "total=%d eligible=%d excluded=%d unschedulable=%d platform_filtered=%d model_unsupported=%d model_rate_limited=%d", + stats.Total, + stats.Eligible, + stats.Excluded, + stats.Unschedulable, + stats.PlatformFiltered, + stats.ModelUnsupported, + stats.ModelRateLimited, + ) +} + // isModelSupportedByAccountWithContext 根据账户平台检查模型支持(带 context) // 对于 Antigravity 平台,会先获取映射后的最终模型名(包括 thinking 后缀)再检查支持 func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Context, account *Account, requestedModel string) bool { @@ -2523,7 +3266,7 @@ func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Contex return false } // 应用 thinking 后缀后检查最终模型是否在账号映射中 - if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok { + if enabled, ok := ThinkingEnabledFromContext(ctx); ok { finalModel := applyThinkingModelSuffix(mapped, enabled) if finalModel == mapped { return true // thinking 后缀未改变模型名,映射已通过 @@ -2543,18 +3286,154 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo } return mapAntigravityModel(account, requestedModel) != "" } + if account.Platform == 
PlatformSora { + return s.isSoraModelSupportedByAccount(account, requestedModel) + } // OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID) if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { requestedModel = claude.NormalizeModelID(requestedModel) } - // Gemini API Key 账户直接透传,由上游判断模型是否支持 - if account.Platform == PlatformGemini && account.Type == AccountTypeAPIKey { - return true - } // 其他平台使用账户的模型支持检查 return account.IsModelSupported(requestedModel) } +func (s *GatewayService) isSoraModelSupportedByAccount(account *Account, requestedModel string) bool { + if account == nil { + return false + } + if strings.TrimSpace(requestedModel) == "" { + return true + } + + // 先走原始精确/通配符匹配。 + mapping := account.GetModelMapping() + if len(mapping) == 0 || account.IsModelSupported(requestedModel) { + return true + } + + aliases := buildSoraModelAliases(requestedModel) + if len(aliases) == 0 { + return false + } + + hasSoraSelector := false + for pattern := range mapping { + if !isSoraModelSelector(pattern) { + continue + } + hasSoraSelector = true + if matchPatternAnyAlias(pattern, aliases) { + return true + } + } + + // 兼容旧账号:mapping 存在但未配置任何 Sora 选择器(例如只含 gpt-*), + // 此时不应误拦截 Sora 模型请求。 + if !hasSoraSelector { + return true + } + + return false +} + +func matchPatternAnyAlias(pattern string, aliases []string) bool { + normalizedPattern := strings.ToLower(strings.TrimSpace(pattern)) + if normalizedPattern == "" { + return false + } + for _, alias := range aliases { + if matchWildcard(normalizedPattern, alias) { + return true + } + } + return false +} + +func isSoraModelSelector(pattern string) bool { + p := strings.ToLower(strings.TrimSpace(pattern)) + if p == "" { + return false + } + + switch { + case strings.HasPrefix(p, "sora"), + strings.HasPrefix(p, "gpt-image"), + strings.HasPrefix(p, "prompt-enhance"), + strings.HasPrefix(p, "sy_"): + return true + } + + return p == "video" || p == "image" +} + +func buildSoraModelAliases(requestedModel string) 
[]string { + modelID := strings.ToLower(strings.TrimSpace(requestedModel)) + if modelID == "" { + return nil + } + + aliases := make([]string, 0, 8) + addAlias := func(value string) { + v := strings.ToLower(strings.TrimSpace(value)) + if v == "" { + return + } + for _, existing := range aliases { + if existing == v { + return + } + } + aliases = append(aliases, v) + } + + addAlias(modelID) + cfg, ok := GetSoraModelConfig(modelID) + if ok { + addAlias(cfg.Model) + switch cfg.Type { + case "video": + addAlias("video") + addAlias("sora") + addAlias(soraVideoFamilyAlias(modelID)) + case "image": + addAlias("image") + addAlias("gpt-image") + case "prompt_enhance": + addAlias("prompt-enhance") + } + return aliases + } + + switch { + case strings.HasPrefix(modelID, "sora"): + addAlias("video") + addAlias("sora") + addAlias(soraVideoFamilyAlias(modelID)) + case strings.HasPrefix(modelID, "gpt-image"): + addAlias("image") + addAlias("gpt-image") + case strings.HasPrefix(modelID, "prompt-enhance"): + addAlias("prompt-enhance") + default: + return nil + } + + return aliases +} + +func soraVideoFamilyAlias(modelID string) string { + switch { + case strings.HasPrefix(modelID, "sora2pro-hd"): + return "sora2pro-hd" + case strings.HasPrefix(modelID, "sora2pro"): + return "sora2pro" + case strings.HasPrefix(modelID, "sora2"): + return "sora2" + default: + return "" + } +} + // GetAccessToken 获取账号凭证 func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { switch account.Type { @@ -2818,7 +3697,7 @@ func injectClaudeCodePrompt(body []byte, system any) []byte { result, err := sjson.SetBytes(body, "system", newSystem) if err != nil { - log.Printf("Warning: failed to inject Claude Code prompt: %v", err) + logger.LegacyPrintf("service.gateway", "Warning: failed to inject Claude Code prompt: %v", err) return body } return result @@ -2974,7 +3853,7 @@ func removeCacheControlFromThinkingBlocks(data map[string]any) { if blockType, _ := 
m["type"].(string); blockType == "thinking" { if _, has := m["cache_control"]; has { delete(m, "cache_control") - log.Printf("[Warning] Removed illegal cache_control from thinking block in system") + logger.LegacyPrintf("service.gateway", "[Warning] Removed illegal cache_control from thinking block in system") } } } @@ -2991,7 +3870,7 @@ func removeCacheControlFromThinkingBlocks(data map[string]any) { if blockType, _ := m["type"].(string); blockType == "thinking" { if _, has := m["cache_control"]; has { delete(m, "cache_control") - log.Printf("[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIdx, contentIdx) + logger.LegacyPrintf("service.gateway", "[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIdx, contentIdx) } } } @@ -3009,6 +3888,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return nil, fmt.Errorf("parse request: empty request") } + if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() { + return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body, parsed.Model, parsed.Stream, startTime) + } + body := parsed.Body reqModel := parsed.Model reqStream := parsed.Stream @@ -3070,7 +3953,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // 替换请求体中的模型名 body = s.replaceModelInBody(body, mappedModel) reqModel = mappedModel - log.Printf("Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource) + logger.LegacyPrintf("service.gateway", "Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource) } // 获取凭证 @@ -3086,16 +3969,16 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } // 调试日志:记录即将转发的账号信息 - log.Printf("[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s", + logger.LegacyPrintf("service.gateway", 
"[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s", account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL) + // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 + setOpsUpstreamRequestBody(c, body) // 重试循环 var resp *http.Response retryStart := time.Now() for attempt := 1; attempt <= maxRetryAttempts; attempt++ { // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) - // Capture upstream request body for ops retry of this attempt. - c.Set(OpsUpstreamRequestBodyKey, string(body)) upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) if err != nil { return nil, err @@ -3166,7 +4049,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A resp.Body = io.NopCloser(bytes.NewReader(respBody)) break } - log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID) + logger.LegacyPrintf("service.gateway", "Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID) // Conservative two-stage fallback: // 1) Disable thinking + thinking->text (preserve content) @@ -3179,7 +4062,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { if retryResp.StatusCode < 400 { - log.Printf("Account %d: signature error retry succeeded (thinking downgraded)", account.ID) + logger.LegacyPrintf("service.gateway", "Account %d: signature error retry succeeded (thinking downgraded)", account.ID) resp = retryResp break } @@ -3204,7 +4087,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A }) msg2 := extractUpstreamErrorMessage(retryRespBody) if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed { - 
log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID) + logger.LegacyPrintf("service.gateway", "Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID) filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body) retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) if buildErr2 == nil { @@ -3224,9 +4107,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A Kind: "signature_retry_tools_request_error", Message: sanitizeUpstreamErrorMessage(retryErr2.Error()), }) - log.Printf("Account %d: tool-downgrade signature retry failed: %v", account.ID, retryErr2) + logger.LegacyPrintf("service.gateway", "Account %d: tool-downgrade signature retry failed: %v", account.ID, retryErr2) } else { - log.Printf("Account %d: tool-downgrade signature retry build failed: %v", account.ID, buildErr2) + logger.LegacyPrintf("service.gateway", "Account %d: tool-downgrade signature retry build failed: %v", account.ID, buildErr2) } } } @@ -3242,9 +4125,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A if retryResp != nil && retryResp.Body != nil { _ = retryResp.Body.Close() } - log.Printf("Account %d: signature error retry failed: %v", account.ID, retryErr) + logger.LegacyPrintf("service.gateway", "Account %d: signature error retry failed: %v", account.ID, retryErr) } else { - log.Printf("Account %d: signature error retry build request failed: %v", account.ID, buildErr) + logger.LegacyPrintf("service.gateway", "Account %d: signature error retry build request failed: %v", account.ID, buildErr) } // Retry failed: restore original response body and continue handling. 
@@ -3290,7 +4173,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return "" }(), }) - log.Printf("Account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)", + logger.LegacyPrintf("service.gateway", "Account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)", account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay, elapsed, maxRetryElapsed) if err := sleepWithContext(ctx, delay); err != nil { return nil, err @@ -3303,10 +4186,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // 不需要重试(成功或不可重试的错误),跳出循环 // DEBUG: 输出响应 headers(用于检测 rate limit 信息) - if account.Platform == PlatformGemini && resp.StatusCode < 400 { - log.Printf("[DEBUG] Gemini API Response Headers for account %d:", account.ID) + if account.Platform == PlatformGemini && resp.StatusCode < 400 && s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders { + logger.LegacyPrintf("service.gateway", "[DEBUG] Gemini API Response Headers for account %d:", account.ID) for k, v := range resp.Header { - log.Printf("[DEBUG] %s: %v", k, v) + logger.LegacyPrintf("service.gateway", "[DEBUG] %s: %v", k, v) } } break @@ -3324,7 +4207,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A resp.Body = io.NopCloser(bytes.NewReader(respBody)) // 调试日志:打印重试耗尽后的错误响应 - log.Printf("[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", + logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) s.handleRetryExhaustedSideEffects(ctx, resp, account) @@ -3355,7 +4238,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A resp.Body = io.NopCloser(bytes.NewReader(respBody)) // 调试日志:打印上游错误响应 - log.Printf("[Forward] Upstream error 
(failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", + logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) s.handleFailoverSideEffects(ctx, resp, account) @@ -3409,13 +4292,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A }) if s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf( + logger.LegacyPrintf("service.gateway", "Account %d: 400 error, attempting failover: %s", account.ID, truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), ) } else { - log.Printf("Account %d: 400 error, attempting failover", account.ID) + logger.LegacyPrintf("service.gateway", "Account %d: 400 error, attempting failover", account.ID) } s.handleFailoverSideEffects(ctx, resp, account) return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} @@ -3425,6 +4308,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } // 处理正常响应 + + // 触发上游接受回调(提前释放串行锁,不等流完成) + if parsed.OnUpstreamAccepted != nil { + parsed.OnUpstreamAccepted() + } + var usage *ClaudeUsage var firstTokenMs *int var clientDisconnect bool @@ -3459,6 +4348,602 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A }, nil } +func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + reqModel string, + reqStream bool, + startTime time.Time, +) (*ForwardResult, error) { + token, tokenType, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, err + } + if tokenType != "apikey" { + return nil, fmt.Errorf("anthropic api key passthrough requires apikey token, got: %s", tokenType) + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + 
logger.LegacyPrintf("service.gateway", "[Anthropic 自动透传] 命中 API Key 透传分支: account=%d name=%s model=%s stream=%v", + account.ID, account.Name, reqModel, reqStream) + + if c != nil { + c.Set("anthropic_passthrough", true) + } + // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 + setOpsUpstreamRequestBody(c, body) + + var resp *http.Response + retryStart := time.Now() + for attempt := 1; attempt <= maxRetryAttempts; attempt++ { + upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(ctx, c, account, body, token) + if err != nil { + return nil, err + } + + resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + if err != nil { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Passthrough: true, + Kind: "request_error", + Message: safeErr, + }) + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed", + }, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + + // 透传分支禁止 400 请求体降级重试(该重试会改写请求体) + if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { + if attempt < maxRetryAttempts { + elapsed := time.Since(retryStart) + if elapsed >= maxRetryElapsed { + break + } + + delay := retryBackoffDelay(attempt) + remaining := maxRetryElapsed - elapsed + if delay > remaining { + delay = remaining + } + if delay <= 0 { + break + } + + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 
resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "retry", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + logger.LegacyPrintf("service.gateway", "Anthropic passthrough account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)", + account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay, elapsed, maxRetryElapsed) + if err := sleepWithContext(ctx, delay); err != nil { + return nil, err + } + continue + } + break + } + + break + } + if resp == nil || resp.Body == nil { + return nil, errors.New("upstream request failed: empty response") + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { + if s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + logger.LegacyPrintf("service.gateway", "[Anthropic Passthrough] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", + account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) + + s.handleRetryExhaustedSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "retry_exhausted_failover", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), 
+ }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + } + return s.handleRetryExhaustedError(ctx, resp, c, account) + } + + if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + logger.LegacyPrintf("service.gateway", "[Anthropic Passthrough] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", + account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) + + s.handleFailoverSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "failover", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + } + + if resp.StatusCode >= 400 { + return s.handleErrorResponse(ctx, resp, c, account) + } + + var usage *ClaudeUsage + var firstTokenMs *int + var clientDisconnect bool + if reqStream { + streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, startTime, reqModel) + if err != nil { + return nil, err + } + usage = streamResult.usage + firstTokenMs = streamResult.firstTokenMs + clientDisconnect = streamResult.clientDisconnect + } else { + usage, err = s.handleNonStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account) + if err != nil { + return nil, err + } + } + if usage == nil { + usage = &ClaudeUsage{} + } + + return 
&ForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: *usage, + Model: reqModel, + Stream: reqStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnect, + }, nil +} + +func (s *GatewayService) buildUpstreamRequestAnthropicAPIKeyPassthrough( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + token string, +) (*http.Request, error) { + targetURL := claudeAPIURL + baseURL := account.GetBaseURL() + if baseURL != "" { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + targetURL = validatedURL + "/v1/messages" + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + if c != nil && c.Request != nil { + for key, values := range c.Request.Header { + lowerKey := strings.ToLower(strings.TrimSpace(key)) + if !allowedHeaders[lowerKey] { + continue + } + for _, v := range values { + req.Header.Add(key, v) + } + } + } + + // 覆盖入站鉴权残留,并注入上游认证 + req.Header.Del("authorization") + req.Header.Del("x-api-key") + req.Header.Del("x-goog-api-key") + req.Header.Del("cookie") + req.Header.Set("x-api-key", token) + + if req.Header.Get("content-type") == "" { + req.Header.Set("content-type", "application/json") + } + if req.Header.Get("anthropic-version") == "" { + req.Header.Set("anthropic-version", "2023-06-01") + } + + return req, nil +} + +func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, + startTime time.Time, + model string, +) (*streamingResult, error) { + if s.rateLimitService != nil { + s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) + } + + writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + + contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) + if contentType == "" 
{ + contentType = "text/event-stream" + } + c.Header("Content-Type", contentType) + if c.Writer.Header().Get("Cache-Control") == "" { + c.Header("Cache-Control", "no-cache") + } + if c.Writer.Header().Get("Connection") == "" { + c.Header("Connection", "keep-alive") + } + c.Header("X-Accel-Buffering", "no") + if v := resp.Header.Get("x-request-id"); v != "" { + c.Header("x-request-id", v) + } + + w := c.Writer + flusher, ok := w.(http.Flusher) + if !ok { + return nil, errors.New("streaming not supported") + } + + usage := &ClaudeUsage{} + var firstTokenMs *int + clientDisconnected := false + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) + + type scanEvent struct { + line string + err error + } + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) + defer close(events) + for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }(scanBuf) + defer close(done) + + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + + for { + select 
{ + case ev, ok := <-events: + if !ok { + if !clientDisconnected { + // 兜底补刷,确保最后一个未以空行结尾的事件也能及时送达客户端。 + flusher.Flush() + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil + } + if ev.err != nil { + if clientDisconnected { + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, ev.err) + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] 流读取被取消: account=%d request_id=%s err=%v ctx_err=%v", + account.ID, resp.Header.Get("x-request-id"), ev.err, ctx.Err()) + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + if errors.Is(ev.err, bufio.ErrTooLong) { + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) + } + + line := ev.line + if data, ok := extractAnthropicSSEDataLine(line); ok { + trimmed := strings.TrimSpace(data) + if firstTokenMs == nil && trimmed != "" && trimmed != "[DONE]" { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseSSEUsagePassthrough(data, usage) + } + + if !clientDisconnected { + if _, err := io.WriteString(w, line); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) + } else if _, err := io.WriteString(w, "\n"); err != nil { + clientDisconnected = true + 
logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) + } else if line == "" { + // 按 SSE 事件边界刷出,减少每行 flush 带来的 syscall 开销。 + flusher.Flush() + } + } + + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + if clientDisconnected { + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream timeout after client disconnect: account=%d model=%s", account.ID, model) + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval) + if s.rateLimitService != nil { + s.rateLimitService.HandleStreamTimeout(ctx, account, model) + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") + } + } +} + +func extractAnthropicSSEDataLine(line string) (string, bool) { + if !strings.HasPrefix(line, "data:") { + return "", false + } + start := len("data:") + for start < len(line) { + if line[start] != ' ' && line[start] != '\t' { + break + } + start++ + } + return line[start:], true +} + +func (s *GatewayService) parseSSEUsagePassthrough(data string, usage *ClaudeUsage) { + if usage == nil || data == "" || data == "[DONE]" { + return + } + + parsed := gjson.Parse(data) + switch parsed.Get("type").String() { + case "message_start": + msgUsage := parsed.Get("message.usage") + if msgUsage.Exists() { + usage.InputTokens = int(msgUsage.Get("input_tokens").Int()) + usage.CacheCreationInputTokens = int(msgUsage.Get("cache_creation_input_tokens").Int()) + usage.CacheReadInputTokens = int(msgUsage.Get("cache_read_input_tokens").Int()) + + // 保持与通用解析一致:message_start 允许覆盖 5m/1h 明细(包括 0)。 + cc5m := 
msgUsage.Get("cache_creation.ephemeral_5m_input_tokens") + cc1h := msgUsage.Get("cache_creation.ephemeral_1h_input_tokens") + if cc5m.Exists() || cc1h.Exists() { + usage.CacheCreation5mTokens = int(cc5m.Int()) + usage.CacheCreation1hTokens = int(cc1h.Int()) + } + } + case "message_delta": + deltaUsage := parsed.Get("usage") + if deltaUsage.Exists() { + if v := deltaUsage.Get("input_tokens").Int(); v > 0 { + usage.InputTokens = int(v) + } + if v := deltaUsage.Get("output_tokens").Int(); v > 0 { + usage.OutputTokens = int(v) + } + if v := deltaUsage.Get("cache_creation_input_tokens").Int(); v > 0 { + usage.CacheCreationInputTokens = int(v) + } + if v := deltaUsage.Get("cache_read_input_tokens").Int(); v > 0 { + usage.CacheReadInputTokens = int(v) + } + + cc5m := deltaUsage.Get("cache_creation.ephemeral_5m_input_tokens") + cc1h := deltaUsage.Get("cache_creation.ephemeral_1h_input_tokens") + if cc5m.Exists() && cc5m.Int() > 0 { + usage.CacheCreation5mTokens = int(cc5m.Int()) + } + if cc1h.Exists() && cc1h.Int() > 0 { + usage.CacheCreation1hTokens = int(cc1h.Int()) + } + } + } + + if usage.CacheReadInputTokens == 0 { + if cached := parsed.Get("message.usage.cached_tokens").Int(); cached > 0 { + usage.CacheReadInputTokens = int(cached) + } + if cached := parsed.Get("usage.cached_tokens").Int(); usage.CacheReadInputTokens == 0 && cached > 0 { + usage.CacheReadInputTokens = int(cached) + } + } + if usage.CacheCreationInputTokens == 0 { + cc5m := parsed.Get("message.usage.cache_creation.ephemeral_5m_input_tokens").Int() + cc1h := parsed.Get("message.usage.cache_creation.ephemeral_1h_input_tokens").Int() + if cc5m == 0 && cc1h == 0 { + cc5m = parsed.Get("usage.cache_creation.ephemeral_5m_input_tokens").Int() + cc1h = parsed.Get("usage.cache_creation.ephemeral_1h_input_tokens").Int() + } + total := cc5m + cc1h + if total > 0 { + usage.CacheCreationInputTokens = int(total) + } + } +} + +func parseClaudeUsageFromResponseBody(body []byte) *ClaudeUsage { + usage := &ClaudeUsage{} 
+ if len(body) == 0 { + return usage + } + + parsed := gjson.ParseBytes(body) + usageNode := parsed.Get("usage") + if !usageNode.Exists() { + return usage + } + + usage.InputTokens = int(usageNode.Get("input_tokens").Int()) + usage.OutputTokens = int(usageNode.Get("output_tokens").Int()) + usage.CacheCreationInputTokens = int(usageNode.Get("cache_creation_input_tokens").Int()) + usage.CacheReadInputTokens = int(usageNode.Get("cache_read_input_tokens").Int()) + + cc5m := usageNode.Get("cache_creation.ephemeral_5m_input_tokens").Int() + cc1h := usageNode.Get("cache_creation.ephemeral_1h_input_tokens").Int() + if cc5m > 0 || cc1h > 0 { + usage.CacheCreation5mTokens = int(cc5m) + usage.CacheCreation1hTokens = int(cc1h) + } + if usage.CacheCreationInputTokens == 0 && (cc5m > 0 || cc1h > 0) { + usage.CacheCreationInputTokens = int(cc5m + cc1h) + } + if usage.CacheReadInputTokens == 0 { + if cached := usageNode.Get("cached_tokens").Int(); cached > 0 { + usage.CacheReadInputTokens = int(cached) + } + } + return usage +} + +func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, +) (*ClaudeUsage, error) { + if s.rateLimitService != nil { + s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) + } + + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } + return nil, err + } + + usage := parseClaudeUsageFromResponseBody(body) + + writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + contentType := 
strings.TrimSpace(resp.Header.Get("Content-Type")) + if contentType == "" { + contentType = "application/json" + } + c.Data(resp.StatusCode, contentType, body) + return usage, nil +} + +func writeAnthropicPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) { + if dst == nil || src == nil { + return + } + if filter != nil { + responseheaders.WriteFilteredHeaders(dst, src, filter) + return + } + if v := strings.TrimSpace(src.Get("Content-Type")); v != "" { + dst.Set("Content-Type", v) + } + if v := strings.TrimSpace(src.Get("x-request-id")); v != "" { + dst.Set("x-request-id", v) + } +} + func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) { // 确定目标URL targetURL := claudeAPIURL @@ -3484,7 +4969,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // 1. 获取或创建指纹(包含随机生成的ClientID) fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders) if err != nil { - log.Printf("Warning: failed to get fingerprint for account %d: %v", account.ID, err) + logger.LegacyPrintf("service.gateway", "Warning: failed to get fingerprint for account %d: %v", account.ID, err) // 失败时降级为透传原始headers } else { fingerprint = fp @@ -3551,12 +5036,11 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // messages requests typically use only oauth + interleaved-thinking. // Also drop claude-code beta if a downstream client added it. 
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking} - drop := map[string]struct{}{claude.BetaClaudeCode: {}} - req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop)) + req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, droppedBetasWithClaudeCodeSet)) } else { // Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta clientBetaHeader := req.Header.Get("anthropic-beta") - req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, clientBetaHeader)) + req.Header.Set("anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), defaultDroppedBetasSet)) } } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) @@ -3710,6 +5194,64 @@ func mergeAnthropicBetaDropping(required []string, incoming string, drop map[str return strings.Join(out, ",") } +// stripBetaTokens removes the given beta tokens from a comma-separated header value. +func stripBetaTokens(header string, tokens []string) string { + if header == "" || len(tokens) == 0 { + return header + } + return stripBetaTokensWithSet(header, buildBetaTokenSet(tokens)) +} + +func stripBetaTokensWithSet(header string, drop map[string]struct{}) string { + if header == "" || len(drop) == 0 { + return header + } + parts := strings.Split(header, ",") + out := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p == "" { + continue + } + if _, ok := drop[p]; ok { + continue + } + out = append(out, p) + } + if len(out) == len(parts) { + return header // no change, avoid allocation + } + return strings.Join(out, ",") +} + +// droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens. 
+func droppedBetaSet(extra ...string) map[string]struct{} { + m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra)) + for t := range defaultDroppedBetasSet { + m[t] = struct{}{} + } + for _, t := range extra { + m[t] = struct{}{} + } + return m +} + +func buildBetaTokenSet(tokens []string) map[string]struct{} { + m := make(map[string]struct{}, len(tokens)) + for _, t := range tokens { + if t == "" { + continue + } + m[t] = struct{}{} + } + return m +} + +var ( + defaultDroppedBetasSet = buildBetaTokenSet(claude.DroppedBetas) + droppedBetasWithClaudeCodeSet = droppedBetaSet(claude.BetaClaudeCode) +) + // applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers. // This mirrors opencode-anthropic-auth behavior: do not trust downstream // headers when using Claude Code-scoped OAuth credentials. @@ -3756,33 +5298,33 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool { } // Log for debugging - log.Printf("[SignatureCheck] Checking error message: %s", msg) + logger.LegacyPrintf("service.gateway", "[SignatureCheck] Checking error message: %s", msg) // 检测signature相关的错误(更宽松的匹配) // 例如: "Invalid `signature` in `thinking` block", "***.signature" 等 if strings.Contains(msg, "signature") { - log.Printf("[SignatureCheck] Detected signature error") + logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected signature error") return true } // 检测 thinking block 顺序/类型错误 // 例如: "Expected `thinking` or `redacted_thinking`, but found `text`" if strings.Contains(msg, "expected") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) { - log.Printf("[SignatureCheck] Detected thinking block type error") + logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected thinking block type error") return true } // 检测 thinking block 被修改的错误 // 例如: "thinking or redacted_thinking blocks in the latest assistant message cannot be modified" if strings.Contains(msg, "cannot be modified") && 
(strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) { - log.Printf("[SignatureCheck] Detected thinking block modification error") + logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected thinking block modification error") return true } // 检测空消息内容错误(可能是过滤 thinking blocks 后导致的) // 例如: "all messages must have non-empty content" if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") { - log.Printf("[SignatureCheck] Detected empty content error") + logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected empty content error") return true } @@ -3790,7 +5332,7 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool { } func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { - // 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。 + // 只对"可能是兼容性差异导致"的 400 允许切换,避免无意义重试。 // 默认保守:无法识别则不切换。 msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) if msg == "" { @@ -3839,11 +5381,25 @@ func extractUpstreamErrorMessage(body []byte) string { return gjson.GetBytes(body, "message").String() } +func isCountTokensUnsupported404(statusCode int, body []byte) bool { + if statusCode != http.StatusNotFound { + return false + } + msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(body))) + if msg == "" { + return false + } + if strings.Contains(msg, "/v1/messages/count_tokens") { + return true + } + return strings.Contains(msg, "count_tokens") && strings.Contains(msg, "not found") +} + func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) // 调试日志:打印上游错误响应 - log.Printf("[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s", + logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s", account.ID, 
account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(body), 1000)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) @@ -3854,7 +5410,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil { if v, ok := c.Get(claudeMimicDebugInfoKey); ok { if line, ok := v.(string); ok && strings.TrimSpace(line) != "" { - log.Printf("[ClaudeMimicDebugOnError] status=%d request_id=%s %s", + logger.LegacyPrintf("service.gateway", "[ClaudeMimicDebugOnError] status=%d request_id=%s %s", resp.StatusCode, resp.Header.Get("x-request-id"), line, @@ -3894,7 +5450,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res // 记录上游错误响应体摘要便于排障(可选:由配置控制;不回显到客户端) if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf( + logger.LegacyPrintf("service.gateway", "Upstream error %d (account=%d platform=%s type=%s): %s", resp.StatusCode, account.ID, @@ -3995,10 +5551,10 @@ func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, re // OAuth/Setup Token 账号的 403:标记账号异常 if account.IsOAuth() && statusCode == 403 { s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body) - log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetryAttempts, statusCode) + logger.LegacyPrintf("service.gateway", "Account %d: marked as error after %d retries for status %d", account.ID, maxRetryAttempts, statusCode) } else { // API Key 未配置错误码:不标记账号状态 - log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetryAttempts) + logger.LegacyPrintf("service.gateway", "Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetryAttempts) } } @@ -4024,7 +5580,7 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht if 
isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil { if v, ok := c.Get(claudeMimicDebugInfoKey); ok { if line, ok := v.(string); ok && strings.TrimSpace(line) != "" { - log.Printf("[ClaudeMimicDebugOnError] status=%d request_id=%s %s", + logger.LegacyPrintf("service.gateway", "[ClaudeMimicDebugOnError] status=%d request_id=%s %s", resp.StatusCode, resp.Header.Get("x-request-id"), line, @@ -4053,7 +5609,7 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht }) if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf( + logger.LegacyPrintf("service.gateway", "Upstream error %d retries_exhausted (account=%d platform=%s type=%s): %s", resp.StatusCode, account.ID, @@ -4116,8 +5672,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) - if s.cfg != nil { - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) } // 设置SSE响应头 @@ -4145,7 +5701,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) type scanEvent struct { line string @@ -4164,7 +5721,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -4175,7 +5733,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http 
if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) streamInterval := time.Duration(0) @@ -4209,9 +5767,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http pendingEventLines := make([]string, 0, 4) - processSSEEvent := func(lines []string) ([]string, string, error) { + processSSEEvent := func(lines []string) ([]string, string, *sseUsagePatch, error) { if len(lines) == 0 { - return nil, "", nil + return nil, "", nil, nil } eventName := "" @@ -4228,11 +5786,11 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } if eventName == "error" { - return nil, dataLine, errors.New("have error in stream") + return nil, dataLine, nil, errors.New("have error in stream") } if dataLine == "" { - return []string{strings.Join(lines, "\n") + "\n\n"}, "", nil + return []string{strings.Join(lines, "\n") + "\n\n"}, "", nil, nil } if dataLine == "[DONE]" { @@ -4241,7 +5799,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http block = "event: " + eventName + "\n" } block += "data: " + dataLine + "\n\n" - return []string{block}, dataLine, nil + return []string{block}, dataLine, nil, nil } var event map[string]any @@ -4252,25 +5810,43 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http block = "event: " + eventName + "\n" } block += "data: " + dataLine + "\n\n" - return []string{block}, dataLine, nil + return []string{block}, dataLine, nil, nil } eventType, _ := event["type"].(string) if eventName == "" { eventName = eventType } + eventChanged := false // 兼容 Kimi cached_tokens → cache_read_input_tokens if eventType == "message_start" { if msg, ok := event["message"].(map[string]any); ok { if u, ok := msg["usage"].(map[string]any); ok { - reconcileCachedTokens(u) + eventChanged = reconcileCachedTokens(u) || eventChanged } } } if eventType == "message_delta" { if u, ok := event["usage"].(map[string]any); 
ok { - reconcileCachedTokens(u) + eventChanged = reconcileCachedTokens(u) || eventChanged + } + } + + // Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类 + if account.IsCacheTTLOverrideEnabled() { + overrideTarget := account.GetCacheTTLOverrideTarget() + if eventType == "message_start" { + if msg, ok := event["message"].(map[string]any); ok { + if u, ok := msg["usage"].(map[string]any); ok { + eventChanged = rewriteCacheCreationJSON(u, overrideTarget) || eventChanged + } + } + } + if eventType == "message_delta" { + if u, ok := event["usage"].(map[string]any); ok { + eventChanged = rewriteCacheCreationJSON(u, overrideTarget) || eventChanged + } } } @@ -4278,10 +5854,21 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if msg, ok := event["message"].(map[string]any); ok { if model, ok := msg["model"].(string); ok && model == mappedModel { msg["model"] = originalModel + eventChanged = true } } } + usagePatch := s.extractSSEUsagePatch(event) + if !eventChanged { + block := "" + if eventName != "" { + block = "event: " + eventName + "\n" + } + block += "data: " + dataLine + "\n\n" + return []string{block}, dataLine, usagePatch, nil + } + newData, err := json.Marshal(event) if err != nil { // 序列化失败,直接透传原始数据 @@ -4290,7 +5877,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http block = "event: " + eventName + "\n" } block += "data: " + dataLine + "\n\n" - return []string{block}, dataLine, nil + return []string{block}, dataLine, usagePatch, nil } block := "" @@ -4298,7 +5885,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http block = "event: " + eventName + "\n" } block += "data: " + string(newData) + "\n\n" - return []string{block}, string(newData), nil + return []string{block}, string(newData), usagePatch, nil } for { @@ -4311,17 +5898,17 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if ev.err != nil { // 检测 context 取消(客户端断开会导致 
context 取消,进而影响上游读取) if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { - log.Printf("Context canceled during streaming, returning collected usage") + logger.LegacyPrintf("service.gateway", "Context canceled during streaming, returning collected usage") return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil } // 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage if clientDisconnected { - log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err) + logger.LegacyPrintf("service.gateway", "Upstream read error after client disconnect: %v, returning collected usage", ev.err) return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil } // 客户端未断开,正常的错误处理 if errors.Is(ev.err, bufio.ErrTooLong) { - log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) + logger.LegacyPrintf("service.gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) sendErrorEvent("response_too_large") return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err } @@ -4336,7 +5923,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http continue } - outputBlocks, data, err := processSSEEvent(pendingEventLines) + outputBlocks, data, usagePatch, err := processSSEEvent(pendingEventLines) pendingEventLines = pendingEventLines[:0] if err != nil { if clientDisconnected { @@ -4349,7 +5936,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if !clientDisconnected { if _, werr := fmt.Fprint(w, block); werr != nil { clientDisconnected = true - log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") + logger.LegacyPrintf("service.gateway", "Client disconnected during streaming, continuing to drain upstream for billing") break } flusher.Flush() @@ -4359,7 +5946,9 @@ func (s 
*GatewayService) handleStreamingResponse(ctx context.Context, resp *http ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } - s.parseSSEUsage(data, usage) + if usagePatch != nil { + mergeSSEUsagePatch(usage, usagePatch) + } } } continue @@ -4374,10 +5963,10 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } if clientDisconnected { // 客户端已断开,上游也超时了,返回已收集的 usage - log.Printf("Upstream timeout after client disconnect, returning collected usage") + logger.LegacyPrintf("service.gateway", "Upstream timeout after client disconnect, returning collected usage") return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil } - log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) + logger.LegacyPrintf("service.gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) // 处理流超时,可能标记账户为临时不可调度或错误状态 if s.rateLimitService != nil { s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel) @@ -4390,54 +5979,241 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { - // 解析message_start获取input tokens(标准Claude API格式) - var msgStart struct { - Type string `json:"type"` - Message struct { - Usage ClaudeUsage `json:"usage"` - } `json:"message"` - } - if json.Unmarshal([]byte(data), &msgStart) == nil && msgStart.Type == "message_start" { - usage.InputTokens = msgStart.Message.Usage.InputTokens - usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens - usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens + if usage == nil { + return } - // 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API) - var msgDelta struct { - Type string `json:"type"` - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int 
`json:"output_tokens"` - CacheCreationInputTokens int `json:"cache_creation_input_tokens"` - CacheReadInputTokens int `json:"cache_read_input_tokens"` - } `json:"usage"` + var event map[string]any + if err := json.Unmarshal([]byte(data), &event); err != nil { + return } - if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" { - // message_delta 仅覆盖存在且非0的字段 - // 避免覆盖 message_start 中已有的值(如 input_tokens) - // Claude API 的 message_delta 通常只包含 output_tokens - if msgDelta.Usage.InputTokens > 0 { - usage.InputTokens = msgDelta.Usage.InputTokens + + if patch := s.extractSSEUsagePatch(event); patch != nil { + mergeSSEUsagePatch(usage, patch) + } +} + +type sseUsagePatch struct { + inputTokens int + hasInputTokens bool + outputTokens int + hasOutputTokens bool + cacheCreationInputTokens int + hasCacheCreationInput bool + cacheReadInputTokens int + hasCacheReadInput bool + cacheCreation5mTokens int + hasCacheCreation5m bool + cacheCreation1hTokens int + hasCacheCreation1h bool +} + +func (s *GatewayService) extractSSEUsagePatch(event map[string]any) *sseUsagePatch { + if len(event) == 0 { + return nil + } + + eventType, _ := event["type"].(string) + switch eventType { + case "message_start": + msg, _ := event["message"].(map[string]any) + usageObj, _ := msg["usage"].(map[string]any) + if len(usageObj) == 0 { + return nil } - if msgDelta.Usage.OutputTokens > 0 { - usage.OutputTokens = msgDelta.Usage.OutputTokens + + patch := &sseUsagePatch{} + patch.hasInputTokens = true + if v, ok := parseSSEUsageInt(usageObj["input_tokens"]); ok { + patch.inputTokens = v } - if msgDelta.Usage.CacheCreationInputTokens > 0 { - usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens + patch.hasCacheCreationInput = true + if v, ok := parseSSEUsageInt(usageObj["cache_creation_input_tokens"]); ok { + patch.cacheCreationInputTokens = v } - if msgDelta.Usage.CacheReadInputTokens > 0 { - usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens 
+ patch.hasCacheReadInput = true + if v, ok := parseSSEUsageInt(usageObj["cache_read_input_tokens"]); ok { + patch.cacheReadInputTokens = v + } + if cc, ok := usageObj["cache_creation"].(map[string]any); ok { + if v, exists := parseSSEUsageInt(cc["ephemeral_5m_input_tokens"]); exists { + patch.cacheCreation5mTokens = v + patch.hasCacheCreation5m = true + } + if v, exists := parseSSEUsageInt(cc["ephemeral_1h_input_tokens"]); exists { + patch.cacheCreation1hTokens = v + patch.hasCacheCreation1h = true + } + } + return patch + + case "message_delta": + usageObj, _ := event["usage"].(map[string]any) + if len(usageObj) == 0 { + return nil + } + + patch := &sseUsagePatch{} + if v, ok := parseSSEUsageInt(usageObj["input_tokens"]); ok && v > 0 { + patch.inputTokens = v + patch.hasInputTokens = true + } + if v, ok := parseSSEUsageInt(usageObj["output_tokens"]); ok && v > 0 { + patch.outputTokens = v + patch.hasOutputTokens = true + } + if v, ok := parseSSEUsageInt(usageObj["cache_creation_input_tokens"]); ok && v > 0 { + patch.cacheCreationInputTokens = v + patch.hasCacheCreationInput = true + } + if v, ok := parseSSEUsageInt(usageObj["cache_read_input_tokens"]); ok && v > 0 { + patch.cacheReadInputTokens = v + patch.hasCacheReadInput = true + } + if cc, ok := usageObj["cache_creation"].(map[string]any); ok { + if v, exists := parseSSEUsageInt(cc["ephemeral_5m_input_tokens"]); exists && v > 0 { + patch.cacheCreation5mTokens = v + patch.hasCacheCreation5m = true + } + if v, exists := parseSSEUsageInt(cc["ephemeral_1h_input_tokens"]); exists && v > 0 { + patch.cacheCreation1hTokens = v + patch.hasCacheCreation1h = true + } + } + return patch + } + + return nil +} + +func mergeSSEUsagePatch(usage *ClaudeUsage, patch *sseUsagePatch) { + if usage == nil || patch == nil { + return + } + + if patch.hasInputTokens { + usage.InputTokens = patch.inputTokens + } + if patch.hasCacheCreationInput { + usage.CacheCreationInputTokens = patch.cacheCreationInputTokens + } + if 
patch.hasCacheReadInput { + usage.CacheReadInputTokens = patch.cacheReadInputTokens + } + if patch.hasOutputTokens { + usage.OutputTokens = patch.outputTokens + } + if patch.hasCacheCreation5m { + usage.CacheCreation5mTokens = patch.cacheCreation5mTokens + } + if patch.hasCacheCreation1h { + usage.CacheCreation1hTokens = patch.cacheCreation1hTokens + } +} + +func parseSSEUsageInt(value any) (int, bool) { + switch v := value.(type) { + case float64: + return int(v), true + case float32: + return int(v), true + case int: + return v, true + case int64: + return int(v), true + case int32: + return int(v), true + case json.Number: + if i, err := v.Int64(); err == nil { + return int(i), true + } + if f, err := v.Float64(); err == nil { + return int(f), true + } + case string: + if parsed, err := strconv.Atoi(strings.TrimSpace(v)); err == nil { + return parsed, true } } + return 0, false +} + +// applyCacheTTLOverride 将所有 cache creation tokens 归入指定的 TTL 类型。 +// target 为 "5m" 或 "1h"。返回 true 表示发生了变更。 +func applyCacheTTLOverride(usage *ClaudeUsage, target string) bool { + // Fallback: 如果只有聚合字段但无 5m/1h 明细,将聚合字段归入 5m 默认类别 + if usage.CacheCreation5mTokens == 0 && usage.CacheCreation1hTokens == 0 && usage.CacheCreationInputTokens > 0 { + usage.CacheCreation5mTokens = usage.CacheCreationInputTokens + } + + total := usage.CacheCreation5mTokens + usage.CacheCreation1hTokens + if total == 0 { + return false + } + switch target { + case "1h": + if usage.CacheCreation1hTokens == total { + return false // 已经全是 1h + } + usage.CacheCreation1hTokens = total + usage.CacheCreation5mTokens = 0 + default: // "5m" + if usage.CacheCreation5mTokens == total { + return false // 已经全是 5m + } + usage.CacheCreation5mTokens = total + usage.CacheCreation1hTokens = 0 + } + return true +} + +// rewriteCacheCreationJSON 在 JSON usage 对象中重写 cache_creation 嵌套对象的 TTL 分类。 +// usageObj 是 usage JSON 对象(map[string]any)。 +func rewriteCacheCreationJSON(usageObj map[string]any, target string) bool { + ccObj, ok := 
usageObj["cache_creation"].(map[string]any) + if !ok { + return false + } + v5m, _ := parseSSEUsageInt(ccObj["ephemeral_5m_input_tokens"]) + v1h, _ := parseSSEUsageInt(ccObj["ephemeral_1h_input_tokens"]) + total := v5m + v1h + if total == 0 { + return false + } + switch target { + case "1h": + if v1h == total { + return false + } + ccObj["ephemeral_1h_input_tokens"] = float64(total) + ccObj["ephemeral_5m_input_tokens"] = float64(0) + default: // "5m" + if v5m == total { + return false + } + ccObj["ephemeral_5m_input_tokens"] = float64(total) + ccObj["ephemeral_1h_input_tokens"] = float64(0) + } + return true } func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) { // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) - body, err := io.ReadAll(resp.Body) + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } return nil, err } @@ -4449,6 +6225,14 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h return nil, fmt.Errorf("parse response: %w", err) } + // 解析嵌套的 cache_creation 对象中的 5m/1h 明细 + cc5m := gjson.GetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens") + cc1h := gjson.GetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens") + if cc5m.Exists() || cc1h.Exists() { + response.Usage.CacheCreation5mTokens = int(cc5m.Int()) + response.Usage.CacheCreation1hTokens = int(cc1h.Int()) + } + // 兼容 Kimi cached_tokens → cache_read_input_tokens if response.Usage.CacheReadInputTokens == 0 { 
cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int() @@ -4460,12 +6244,26 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h } } + // Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类 + if account.IsCacheTTLOverrideEnabled() { + overrideTarget := account.GetCacheTTLOverrideTarget() + if applyCacheTTLOverride(&response.Usage, overrideTarget) { + // 同步更新 body JSON 中的嵌套 cache_creation 对象 + if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens", response.Usage.CacheCreation5mTokens); err == nil { + body = newBody + } + if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens", response.Usage.CacheCreation1hTokens); err == nil { + body = newBody + } + } + } + // 如果有模型映射,替换响应中的model字段 if originalModel != mappedModel { body = s.replaceModelInResponseBody(body, mappedModel, originalModel) } - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := "application/json" if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled { @@ -4481,24 +6279,76 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h } // replaceModelInResponseBody 替换响应体中的model字段 +// 使用 gjson/sjson 精确替换,避免全量 JSON 反序列化 func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte { - var resp map[string]any - if err := json.Unmarshal(body, &resp); err != nil { - return body + if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel { + newBody, err := sjson.SetBytes(body, "model", toModel) + if err != nil { + return body + } + return newBody + } + return body +} + +func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 { + if s == nil || userID <= 0 || groupID <= 0 { + return 
groupDefaultMultiplier } - model, ok := resp["model"].(string) - if !ok || model != fromModel { - return body + key := fmt.Sprintf("%d:%d", userID, groupID) + if s.userGroupRateCache != nil { + if cached, ok := s.userGroupRateCache.Get(key); ok { + if multiplier, castOK := cached.(float64); castOK { + userGroupRateCacheHitTotal.Add(1) + return multiplier + } + } } + if s.userGroupRateRepo == nil { + return groupDefaultMultiplier + } + userGroupRateCacheMissTotal.Add(1) - resp["model"] = toModel - newBody, err := json.Marshal(resp) + value, err, shared := s.userGroupRateSF.Do(key, func() (any, error) { + if s.userGroupRateCache != nil { + if cached, ok := s.userGroupRateCache.Get(key); ok { + if multiplier, castOK := cached.(float64); castOK { + userGroupRateCacheHitTotal.Add(1) + return multiplier, nil + } + } + } + + userGroupRateCacheLoadTotal.Add(1) + userRate, repoErr := s.userGroupRateRepo.GetByUserAndGroup(ctx, userID, groupID) + if repoErr != nil { + return nil, repoErr + } + multiplier := groupDefaultMultiplier + if userRate != nil { + multiplier = *userRate + } + if s.userGroupRateCache != nil { + s.userGroupRateCache.Set(key, multiplier, resolveUserGroupRateCacheTTL(s.cfg)) + } + return multiplier, nil + }) + if shared { + userGroupRateCacheSFSharedTotal.Add(1) + } if err != nil { - return body + userGroupRateCacheFallbackTotal.Add(1) + logger.LegacyPrintf("service.gateway", "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err) + return groupDefaultMultiplier } - return newBody + multiplier, ok := value.(float64) + if !ok { + userGroupRateCacheFallbackTotal.Add(1) + return groupDefaultMultiplier + } + return multiplier } // RecordUsageInput 记录使用量的输入参数 @@ -4514,9 +6364,10 @@ type RecordUsageInput struct { APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 } -// APIKeyQuotaUpdater defines the interface for updating API Key quota +// APIKeyQuotaUpdater defines the interface for updating API Key quota and 
rate limit usage type APIKeyQuotaUpdater interface { UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error + UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error } // RecordUsage 记录使用量并扣费(或更新订阅用量) @@ -4530,29 +6381,50 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu // 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens // 用于粘性会话切换时的特殊计费处理 if input.ForceCacheBilling && result.Usage.InputTokens > 0 { - log.Printf("force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)", + logger.LegacyPrintf("service.gateway", "force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)", result.Usage.InputTokens, account.ID) result.Usage.CacheReadInputTokens += result.Usage.InputTokens result.Usage.InputTokens = 0 } - // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) - multiplier := s.cfg.Default.RateMultiplier - if apiKey.GroupID != nil && apiKey.Group != nil { - multiplier = apiKey.Group.RateMultiplier + // Cache TTL Override: 确保计费时 token 分类与账号设置一致 + cacheTTLOverridden := false + if account.IsCacheTTLOverrideEnabled() { + applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget()) + cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 + } - // 检查用户专属倍率 - if s.userGroupRateRepo != nil { - if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil { - multiplier = *userRate - } - } + // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) + multiplier := 1.0 + if s.cfg != nil { + multiplier = s.cfg.Default.RateMultiplier + } + if apiKey.GroupID != nil && apiKey.Group != nil { + groupDefault := apiKey.Group.RateMultiplier + multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault) } var cost *CostBreakdown // 根据请求类型选择计费方式 - if result.ImageCount > 0 { + if result.MediaType == "image" || result.MediaType == "video" { + var soraConfig *SoraPriceConfig + if 
apiKey.Group != nil { + soraConfig = &SoraPriceConfig{ + ImagePrice360: apiKey.Group.SoraImagePrice360, + ImagePrice540: apiKey.Group.SoraImagePrice540, + VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, + VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, + } + } + if result.MediaType == "image" { + cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier) + } else { + cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier) + } + } else if result.MediaType == "prompt" { + cost = &CostBreakdown{} + } else if result.ImageCount > 0 { // 图片生成计费 var groupConfig *ImagePriceConfig if apiKey.Group != nil { @@ -4566,15 +6438,17 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu } else { // Token 计费 tokens := UsageTokens{ - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - CacheCreationTokens: result.Usage.CacheCreationInputTokens, - CacheReadTokens: result.Usage.CacheReadInputTokens, + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, + CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, } var err error cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier) if err != nil { - log.Printf("Calculate cost failed: %v", err) + logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) cost = &CostBreakdown{ActualCost: 0} } } @@ -4592,6 +6466,10 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu if result.ImageSize != "" { imageSize = &result.ImageSize } + var mediaType *string + if strings.TrimSpace(result.MediaType) != "" { + mediaType = &result.MediaType + } accountRateMultiplier := account.BillingRateMultiplier() usageLog := 
&UsageLog{ UserID: user.ID, @@ -4603,6 +6481,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu OutputTokens: result.Usage.OutputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens, + CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, + CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, InputCost: cost.InputCost, OutputCost: cost.OutputCost, CacheCreationCost: cost.CacheCreationCost, @@ -4617,6 +6497,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu FirstTokenMs: result.FirstTokenMs, ImageCount: result.ImageCount, ImageSize: imageSize, + MediaType: mediaType, + CacheTTLOverridden: cacheTTLOverridden, CreatedAt: time.Now(), } @@ -4640,11 +6522,11 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu inserted, err := s.usageLogRepo.Create(ctx, usageLog) if err != nil { - log.Printf("Create usage log failed: %v", err) + logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err) } if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } @@ -4656,7 +6538,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) if shouldBill && cost.TotalCost > 0 { if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { - log.Printf("Increment subscription usage failed: %v", err) + logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err) } // 异步更新订阅缓存 s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, 
cost.TotalCost) @@ -4665,7 +6547,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) if shouldBill && cost.ActualCost > 0 { if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { - log.Printf("Deduct balance failed: %v", err) + logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err) } // 异步更新余额缓存 s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) @@ -4675,10 +6557,18 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu // 更新 API Key 配额(如果设置了配额限制) if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil { if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { - log.Printf("Update API key quota failed: %v", err) + logger.LegacyPrintf("service.gateway", "Update API key quota failed: %v", err) } } + // Update API Key rate limit usage + if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil { + if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil { + logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err) + } + s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost) + } + // Schedule batch update for account last_used_at s.deferredService.ScheduleLastUsedUpdate(account.ID) @@ -4711,23 +6601,27 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * // 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens // 用于粘性会话切换时的特殊计费处理 if input.ForceCacheBilling && result.Usage.InputTokens > 0 { - log.Printf("force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)", + logger.LegacyPrintf("service.gateway", "force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)", result.Usage.InputTokens, account.ID) result.Usage.CacheReadInputTokens += 
result.Usage.InputTokens result.Usage.InputTokens = 0 } - // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) - multiplier := s.cfg.Default.RateMultiplier - if apiKey.GroupID != nil && apiKey.Group != nil { - multiplier = apiKey.Group.RateMultiplier + // Cache TTL Override: 确保计费时 token 分类与账号设置一致 + cacheTTLOverridden := false + if account.IsCacheTTLOverrideEnabled() { + applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget()) + cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 + } - // 检查用户专属倍率 - if s.userGroupRateRepo != nil { - if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil { - multiplier = *userRate - } - } + // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) + multiplier := 1.0 + if s.cfg != nil { + multiplier = s.cfg.Default.RateMultiplier + } + if apiKey.GroupID != nil && apiKey.Group != nil { + groupDefault := apiKey.Group.RateMultiplier + multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault) } var cost *CostBreakdown @@ -4747,15 +6641,17 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * } else { // Token 计费(使用长上下文计费方法) tokens := UsageTokens{ - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - CacheCreationTokens: result.Usage.CacheCreationInputTokens, - CacheReadTokens: result.Usage.CacheReadInputTokens, + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, + CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, } var err error cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) if err != nil { - log.Printf("Calculate cost failed: %v", err) + 
logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) cost = &CostBreakdown{ActualCost: 0} } } @@ -4784,6 +6680,8 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * OutputTokens: result.Usage.OutputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens, + CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, + CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, InputCost: cost.InputCost, OutputCost: cost.OutputCost, CacheCreationCost: cost.CacheCreationCost, @@ -4798,6 +6696,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * FirstTokenMs: result.FirstTokenMs, ImageCount: result.ImageCount, ImageSize: imageSize, + CacheTTLOverridden: cacheTTLOverridden, CreatedAt: time.Now(), } @@ -4821,11 +6720,11 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * inserted, err := s.usageLogRepo.Create(ctx, usageLog) if err != nil { - log.Printf("Create usage log failed: %v", err) + logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err) } if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } @@ -4837,7 +6736,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) if shouldBill && cost.TotalCost > 0 { if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { - log.Printf("Increment subscription usage failed: %v", err) + logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err) } // 异步更新订阅缓存 
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) @@ -4846,19 +6745,27 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) if shouldBill && cost.ActualCost > 0 { if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { - log.Printf("Deduct balance failed: %v", err) + logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err) } // 异步更新余额缓存 s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) // API Key 独立配额扣费 if input.APIKeyService != nil && apiKey.Quota > 0 { if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { - log.Printf("Add API key quota used failed: %v", err) + logger.LegacyPrintf("service.gateway", "Add API key quota used failed: %v", err) } } } } + // Update API Key rate limit usage + if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil { + if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil { + logger.LegacyPrintf("service.gateway", "Update API key rate limit usage failed: %v", err) + } + s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost) + } + // Schedule batch update for account last_used_at s.deferredService.ScheduleLastUsedUpdate(account.ID) @@ -4873,6 +6780,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, return fmt.Errorf("parse request: empty request") } + if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() { + return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body) + } + body := parsed.Body reqModel := parsed.Model @@ -4884,9 +6795,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) } - // Antigravity 账户不支持 count_tokens 转发,直接返回空值 + // 
Antigravity 账户不支持 count_tokens,返回 404 让客户端 fallback 到本地估算。 + // 返回 nil 避免 handler 层记录为错误,也不设置 ops 上游错误上下文。 if account.Platform == PlatformAntigravity { - c.JSON(http.StatusOK, gin.H{"input_tokens": 0}) + s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for this platform") return nil } @@ -4912,7 +6824,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, if mappedModel != reqModel { body = s.replaceModelInBody(body, mappedModel) reqModel = mappedModel - log.Printf("CountTokens model mapping applied: %s -> %s (account: %s, source=%s)", parsed.Model, mappedModel, account.Name, mappingSource) + logger.LegacyPrintf("service.gateway", "CountTokens model mapping applied: %s -> %s (account: %s, source=%s)", parsed.Model, mappedModel, account.Name, mappingSource) } } @@ -4945,16 +6857,22 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } // 读取响应体 - respBody, err := io.ReadAll(resp.Body) + maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg) + respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes) _ = resp.Body.Close() if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") + return err + } s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") return err } // 检测 thinking block 签名错误(400)并重试一次(过滤 thinking blocks) if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) { - log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID) + logger.LegacyPrintf("service.gateway", "Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID) filteredBody := 
FilterThinkingBlocksForRetry(body) retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, shouldMimicClaudeCode) @@ -4962,9 +6880,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { resp = retryResp - respBody, err = io.ReadAll(resp.Body) + respBody, err = readUpstreamResponseBodyLimited(resp.Body, maxReadBytes) _ = resp.Body.Close() if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") + return err + } s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") return err } @@ -4991,7 +6914,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, // 记录上游错误摘要便于排障(不回显请求内容) if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf( + logger.LegacyPrintf("service.gateway", "count_tokens upstream error %d (account=%d platform=%s type=%s): %s", resp.StatusCode, account.ID, @@ -5021,6 +6944,170 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, return nil } +func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx context.Context, c *gin.Context, account *Account, body []byte) error { + token, tokenType, err := s.GetAccessToken(ctx, account) + if err != nil { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to get access token") + return err + } + if tokenType != "apikey" { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Invalid account token type") + return fmt.Errorf("anthropic api key passthrough requires apikey token, got: %s", tokenType) + } + + upstreamReq, err := 
s.buildCountTokensRequestAnthropicAPIKeyPassthrough(ctx, c, account, body, token) + if err != nil { + s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request") + return err + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + if err != nil { + setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Passthrough: true, + Kind: "request_error", + Message: sanitizeUpstreamErrorMessage(err.Error()), + }) + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed") + return fmt.Errorf("upstream request failed: %w", err) + } + + maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg) + respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes) + _ = resp.Body.Close() + if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") + return err + } + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") + return err + } + + if resp.StatusCode >= 400 { + if s.rateLimitService != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + + // 中转站不支持 count_tokens 端点时(404),返回 404 让客户端 fallback 到本地估算。 + // 仅在错误消息明确指向 count_tokens endpoint 不存在时生效,避免误吞其他 404(如错误 base_url)。 + // 返回 nil 避免 handler 层记录为错误,也不设置 ops 上游错误上下文。 + if 
isCountTokensUnsupported404(resp.StatusCode, respBody) { + logger.LegacyPrintf("service.gateway", + "[count_tokens] Upstream does not support count_tokens (404), returning 404: account=%d name=%s msg=%s", + account.ID, account.Name, truncateString(upstreamMsg, 512)) + s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported by upstream") + return nil + } + + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + errMsg := "Upstream request failed" + switch resp.StatusCode { + case 429: + errMsg = "Rate limit exceeded" + case 529: + errMsg = "Service overloaded" + } + s.countTokensError(c, resp.StatusCode, "upstream_error", errMsg) + if upstreamMsg == "" { + return fmt.Errorf("upstream error: %d", resp.StatusCode) + } + return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) + } + + writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) + if contentType == "" { + contentType = "application/json" + } + c.Data(resp.StatusCode, contentType, respBody) + return nil +} + +func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + token string, +) (*http.Request, error) { + targetURL := claudeAPICountTokensURL + baseURL := 
account.GetBaseURL() + if baseURL != "" { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + targetURL = validatedURL + "/v1/messages/count_tokens" + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + if c != nil && c.Request != nil { + for key, values := range c.Request.Header { + lowerKey := strings.ToLower(strings.TrimSpace(key)) + if !allowedHeaders[lowerKey] { + continue + } + for _, v := range values { + req.Header.Add(key, v) + } + } + } + + req.Header.Del("authorization") + req.Header.Del("x-api-key") + req.Header.Del("x-goog-api-key") + req.Header.Del("cookie") + req.Header.Set("x-api-key", token) + + if req.Header.Get("content-type") == "" { + req.Header.Set("content-type", "application/json") + } + if req.Header.Get("anthropic-version") == "" { + req.Header.Set("anthropic-version", "2023-06-01") + } + + return req, nil +} + // buildCountTokensRequest 构建 count_tokens 上游请求 func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, mimicClaudeCode bool) (*http.Request, error) { // 确定目标 URL @@ -5103,7 +7190,8 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con incomingBeta := req.Header.Get("anthropic-beta") requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting} - req.Header.Set("anthropic-beta", mergeAnthropicBeta(requiredBetas, incomingBeta)) + drop := droppedBetaSet() + req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop)) } else { clientBetaHeader := req.Header.Get("anthropic-beta") if clientBetaHeader == "" { @@ -5113,7 +7201,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con if !strings.Contains(beta, claude.BetaTokenCounting) { beta = beta + "," 
+ claude.BetaTokenCounting } - req.Header.Set("anthropic-beta", beta) + req.Header.Set("anthropic-beta", stripBetaTokensWithSet(beta, defaultDroppedBetasSet)) } } } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { @@ -5168,6 +7256,17 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { // GetAvailableModels returns the list of models available for a group // It aggregates model_mapping keys from all schedulable accounts in the group func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string { + cacheKey := modelsListCacheKey(groupID, platform) + if s.modelsListCache != nil { + if cached, found := s.modelsListCache.Get(cacheKey); found { + if models, ok := cached.([]string); ok { + modelsListCacheHitTotal.Add(1) + return cloneStringSlice(models) + } + } + } + modelsListCacheMissTotal.Add(1) + var accounts []Account var err error @@ -5208,6 +7307,10 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, // If no account has model_mapping, return nil (use default) if !hasAnyMapping { + if s.modelsListCache != nil { + s.modelsListCache.Set(cacheKey, []string(nil), s.modelsListCacheTTL) + modelsListCacheStoreTotal.Add(1) + } return nil } @@ -5216,8 +7319,45 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, for model := range modelSet { models = append(models, model) } + sort.Strings(models) - return models + if s.modelsListCache != nil { + s.modelsListCache.Set(cacheKey, cloneStringSlice(models), s.modelsListCacheTTL) + modelsListCacheStoreTotal.Add(1) + } + return cloneStringSlice(models) +} + +func (s *GatewayService) InvalidateAvailableModelsCache(groupID *int64, platform string) { + if s == nil || s.modelsListCache == nil { + return + } + + normalizedPlatform := strings.TrimSpace(platform) + // 完整匹配时精准失效;否则按维度批量失效。 + if groupID != nil && normalizedPlatform != "" { + 
s.modelsListCache.Delete(modelsListCacheKey(groupID, normalizedPlatform)) + return + } + + targetGroup := derefGroupID(groupID) + for key := range s.modelsListCache.Items() { + parts := strings.SplitN(key, "|", 2) + if len(parts) != 2 { + continue + } + groupPart, parseErr := strconv.ParseInt(parts[0], 10, 64) + if parseErr != nil { + continue + } + if groupID != nil && groupPart != targetGroup { + continue + } + if normalizedPlatform != "" && parts[1] != normalizedPlatform { + continue + } + s.modelsListCache.Delete(key) + } } // reconcileCachedTokens 兼容 Kimi 等上游: diff --git a/backend/internal/service/gateway_service_selection_failure_stats_test.go b/backend/internal/service/gateway_service_selection_failure_stats_test.go new file mode 100644 index 00000000..743d70bb --- /dev/null +++ b/backend/internal/service/gateway_service_selection_failure_stats_test.go @@ -0,0 +1,141 @@ +package service + +import ( + "context" + "strings" + "testing" + "time" +) + +func TestCollectSelectionFailureStats(t *testing.T) { + svc := &GatewayService{} + model := "sora2-landscape-10s" + resetAt := time.Now().Add(2 * time.Minute).Format(time.RFC3339) + + accounts := []Account{ + // excluded + { + ID: 1, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + }, + // unschedulable + { + ID: 2, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: false, + }, + // platform filtered + { + ID: 3, + Platform: PlatformOpenAI, + Status: StatusActive, + Schedulable: true, + }, + // model unsupported + { + ID: 4, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-image": "gpt-image", + }, + }, + }, + // model rate limited + { + ID: 5, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + Extra: map[string]any{ + "model_rate_limits": map[string]any{ + model: map[string]any{ + "rate_limit_reset_at": resetAt, + }, + }, + }, + }, + // eligible + { + ID: 6, + 
Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + }, + } + + excluded := map[int64]struct{}{1: {}} + stats := svc.collectSelectionFailureStats(context.Background(), accounts, model, PlatformSora, excluded, false) + + if stats.Total != 6 { + t.Fatalf("total=%d want=6", stats.Total) + } + if stats.Excluded != 1 { + t.Fatalf("excluded=%d want=1", stats.Excluded) + } + if stats.Unschedulable != 1 { + t.Fatalf("unschedulable=%d want=1", stats.Unschedulable) + } + if stats.PlatformFiltered != 1 { + t.Fatalf("platform_filtered=%d want=1", stats.PlatformFiltered) + } + if stats.ModelUnsupported != 1 { + t.Fatalf("model_unsupported=%d want=1", stats.ModelUnsupported) + } + if stats.ModelRateLimited != 1 { + t.Fatalf("model_rate_limited=%d want=1", stats.ModelRateLimited) + } + if stats.Eligible != 1 { + t.Fatalf("eligible=%d want=1", stats.Eligible) + } +} + +func TestDiagnoseSelectionFailure_SoraUnschedulableDetail(t *testing.T) { + svc := &GatewayService{} + acc := &Account{ + ID: 7, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: false, + } + + diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false) + if diagnosis.Category != "unschedulable" { + t.Fatalf("category=%s want=unschedulable", diagnosis.Category) + } + if diagnosis.Detail != "schedulable=false" { + t.Fatalf("detail=%s want=schedulable=false", diagnosis.Detail) + } +} + +func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) { + svc := &GatewayService{} + model := "sora2-landscape-10s" + resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339) + acc := &Account{ + ID: 8, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + Extra: map[string]any{ + "model_rate_limits": map[string]any{ + model: map[string]any{ + "rate_limit_reset_at": resetAt, + }, + }, + }, + } + + diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, model, 
PlatformSora, map[int64]struct{}{}, false) + if diagnosis.Category != "model_rate_limited" { + t.Fatalf("category=%s want=model_rate_limited", diagnosis.Category) + } + if !strings.Contains(diagnosis.Detail, "remaining=") { + t.Fatalf("detail=%s want contains remaining=", diagnosis.Detail) + } +} diff --git a/backend/internal/service/gateway_service_sora_model_support_test.go b/backend/internal/service/gateway_service_sora_model_support_test.go new file mode 100644 index 00000000..8ee2a960 --- /dev/null +++ b/backend/internal/service/gateway_service_sora_model_support_test.go @@ -0,0 +1,79 @@ +package service + +import "testing" + +func TestGatewayServiceIsModelSupportedByAccount_SoraNoMappingAllowsAll(t *testing.T) { + svc := &GatewayService{} + account := &Account{ + Platform: PlatformSora, + Credentials: map[string]any{}, + } + + if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") { + t.Fatalf("expected sora model to be supported when model_mapping is empty") + } +} + +func TestGatewayServiceIsModelSupportedByAccount_SoraLegacyNonSoraMappingDoesNotBlock(t *testing.T) { + svc := &GatewayService{} + account := &Account{ + Platform: PlatformSora, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-4o": "gpt-4o", + }, + }, + } + + if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") { + t.Fatalf("expected sora model to be supported when mapping has no sora selectors") + } +} + +func TestGatewayServiceIsModelSupportedByAccount_SoraFamilyAlias(t *testing.T) { + svc := &GatewayService{} + account := &Account{ + Platform: PlatformSora, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "sora2": "sora2", + }, + }, + } + + if !svc.isModelSupportedByAccount(account, "sora2-landscape-15s") { + t.Fatalf("expected family selector sora2 to support sora2-landscape-15s") + } +} + +func TestGatewayServiceIsModelSupportedByAccount_SoraUnderlyingModelAlias(t *testing.T) { + svc := &GatewayService{} + account := 
&Account{ + Platform: PlatformSora, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "sy_8": "sy_8", + }, + }, + } + + if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") { + t.Fatalf("expected underlying model selector sy_8 to support sora2-landscape-10s") + } +} + +func TestGatewayServiceIsModelSupportedByAccount_SoraExplicitImageSelectorBlocksVideo(t *testing.T) { + svc := &GatewayService{} + account := &Account{ + Platform: PlatformSora, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-image": "gpt-image", + }, + }, + } + + if svc.isModelSupportedByAccount(account, "sora2-landscape-10s") { + t.Fatalf("expected video model to be blocked when mapping explicitly only allows gpt-image") + } +} diff --git a/backend/internal/service/gateway_service_sora_scheduling_test.go b/backend/internal/service/gateway_service_sora_scheduling_test.go new file mode 100644 index 00000000..5178e68e --- /dev/null +++ b/backend/internal/service/gateway_service_sora_scheduling_test.go @@ -0,0 +1,89 @@ +package service + +import ( + "context" + "testing" + "time" +) + +func TestGatewayServiceIsAccountSchedulableForSelectionSoraIgnoresGenericWindows(t *testing.T) { + svc := &GatewayService{} + now := time.Now() + past := now.Add(-1 * time.Minute) + future := now.Add(5 * time.Minute) + + acc := &Account{ + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + AutoPauseOnExpired: true, + ExpiresAt: &past, + OverloadUntil: &future, + RateLimitResetAt: &future, + } + + if !svc.isAccountSchedulableForSelection(acc) { + t.Fatalf("expected sora account to ignore generic expiry/overload/rate-limit windows") + } +} + +func TestGatewayServiceIsAccountSchedulableForSelectionNonSoraKeepsGenericLogic(t *testing.T) { + svc := &GatewayService{} + future := time.Now().Add(5 * time.Minute) + + acc := &Account{ + Platform: PlatformAnthropic, + Status: StatusActive, + Schedulable: true, + RateLimitResetAt: &future, + } + + if 
svc.isAccountSchedulableForSelection(acc) { + t.Fatalf("expected non-sora account to keep generic schedulable checks") + } +} + +func TestGatewayServiceIsAccountSchedulableForModelSelectionSoraChecksModelScopeOnly(t *testing.T) { + svc := &GatewayService{} + model := "sora2-landscape-10s" + resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339) + globalResetAt := time.Now().Add(2 * time.Minute) + + acc := &Account{ + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + RateLimitResetAt: &globalResetAt, + Extra: map[string]any{ + "model_rate_limits": map[string]any{ + model: map[string]any{ + "rate_limit_reset_at": resetAt, + }, + }, + }, + } + + if svc.isAccountSchedulableForModelSelection(context.Background(), acc, model) { + t.Fatalf("expected sora account to be blocked by model scope rate limit") + } +} + +func TestCollectSelectionFailureStatsSoraIgnoresGenericUnschedulableWindows(t *testing.T) { + svc := &GatewayService{} + future := time.Now().Add(3 * time.Minute) + + accounts := []Account{ + { + ID: 1, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + RateLimitResetAt: &future, + }, + } + + stats := svc.collectSelectionFailureStats(context.Background(), accounts, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false) + if stats.Unschedulable != 0 || stats.Eligible != 1 { + t.Fatalf("unexpected stats: unschedulable=%d eligible=%d", stats.Unschedulable, stats.Eligible) + } +} diff --git a/backend/internal/service/gateway_service_streaming_test.go b/backend/internal/service/gateway_service_streaming_test.go new file mode 100644 index 00000000..c8803d39 --- /dev/null +++ b/backend/internal/service/gateway_service_streaming_test.go @@ -0,0 +1,52 @@ +package service + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func 
TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + + svc := &GatewayService{ + cfg: cfg, + rateLimitService: &RateLimitService{}, + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + defer func() { _ = pw.Close() }() + // Minimal SSE event to trigger parseSSEUsage + _, _ = pw.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":3}}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":7}}\n\n")) + _, _ = pw.Write([]byte("data: [DONE]\n\n")) + }() + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 3, result.usage.InputTokens) + require.Equal(t, 7, result.usage.OutputTokens) +} diff --git a/backend/internal/service/gateway_streaming_test.go b/backend/internal/service/gateway_streaming_test.go new file mode 100644 index 00000000..cd690cbd --- /dev/null +++ b/backend/internal/service/gateway_streaming_test.go @@ -0,0 +1,219 @@ +//go:build unit + +package service + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// --- parseSSEUsage 测试 --- + +func newMinimalGatewayService() *GatewayService { + return &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + MaxLineSize: 
defaultMaxLineSize, + }, + }, + rateLimitService: &RateLimitService{}, + } +} + +func TestParseSSEUsage_MessageStart(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + data := `{"type":"message_start","message":{"usage":{"input_tokens":100,"cache_creation_input_tokens":50,"cache_read_input_tokens":200}}}` + svc.parseSSEUsage(data, usage) + + require.Equal(t, 100, usage.InputTokens) + require.Equal(t, 50, usage.CacheCreationInputTokens) + require.Equal(t, 200, usage.CacheReadInputTokens) + require.Equal(t, 0, usage.OutputTokens, "message_start 不应设置 output_tokens") +} + +func TestParseSSEUsage_MessageDelta(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + data := `{"type":"message_delta","usage":{"output_tokens":42}}` + svc.parseSSEUsage(data, usage) + + require.Equal(t, 42, usage.OutputTokens) + require.Equal(t, 0, usage.InputTokens, "message_delta 的 output_tokens 不应影响已有的 input_tokens") +} + +func TestParseSSEUsage_DeltaDoesNotOverwriteStartValues(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // 先处理 message_start + svc.parseSSEUsage(`{"type":"message_start","message":{"usage":{"input_tokens":100}}}`, usage) + require.Equal(t, 100, usage.InputTokens) + + // 再处理 message_delta(output_tokens > 0, input_tokens = 0) + svc.parseSSEUsage(`{"type":"message_delta","usage":{"output_tokens":50}}`, usage) + require.Equal(t, 100, usage.InputTokens, "delta 中 input_tokens=0 不应覆盖 start 中的值") + require.Equal(t, 50, usage.OutputTokens) +} + +func TestParseSSEUsage_DeltaOverwritesWithNonZero(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // GLM 等 API 会在 delta 中包含所有 usage 信息 + svc.parseSSEUsage(`{"type":"message_delta","usage":{"input_tokens":200,"output_tokens":100,"cache_creation_input_tokens":30,"cache_read_input_tokens":60}}`, usage) + require.Equal(t, 200, usage.InputTokens) + require.Equal(t, 100, usage.OutputTokens) + require.Equal(t, 30, 
usage.CacheCreationInputTokens) + require.Equal(t, 60, usage.CacheReadInputTokens) +} + +func TestParseSSEUsage_DeltaDoesNotResetCacheCreationBreakdown(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // 先在 message_start 中写入非零 5m/1h 明细 + svc.parseSSEUsage(`{"type":"message_start","message":{"usage":{"input_tokens":100,"cache_creation":{"ephemeral_5m_input_tokens":30,"ephemeral_1h_input_tokens":70}}}}`, usage) + require.Equal(t, 30, usage.CacheCreation5mTokens) + require.Equal(t, 70, usage.CacheCreation1hTokens) + + // 后续 delta 带默认 0,不应覆盖已有非零值 + svc.parseSSEUsage(`{"type":"message_delta","usage":{"output_tokens":12,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0}}}`, usage) + require.Equal(t, 30, usage.CacheCreation5mTokens, "delta 的 0 值不应重置 5m 明细") + require.Equal(t, 70, usage.CacheCreation1hTokens, "delta 的 0 值不应重置 1h 明细") + require.Equal(t, 12, usage.OutputTokens) +} + +func TestParseSSEUsage_InvalidJSON(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // 无效 JSON 不应 panic + svc.parseSSEUsage("not json", usage) + require.Equal(t, 0, usage.InputTokens) + require.Equal(t, 0, usage.OutputTokens) +} + +func TestParseSSEUsage_UnknownType(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // 不是 message_start 或 message_delta 的类型 + svc.parseSSEUsage(`{"type":"content_block_delta","delta":{"text":"hello"}}`, usage) + require.Equal(t, 0, usage.InputTokens) + require.Equal(t, 0, usage.OutputTokens) +} + +func TestParseSSEUsage_EmptyString(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + svc.parseSSEUsage("", usage) + require.Equal(t, 0, usage.InputTokens) +} + +func TestParseSSEUsage_DoneEvent(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // [DONE] 事件不应影响 usage + svc.parseSSEUsage("[DONE]", usage) + require.Equal(t, 0, usage.InputTokens) +} + +// --- 流式响应端到端测试 --- + +func 
TestHandleStreamingResponse_CacheTokens(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newMinimalGatewayService() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":10,\"cache_creation_input_tokens\":20,\"cache_read_input_tokens\":30}}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":15}}\n\n")) + _, _ = pw.Write([]byte("data: [DONE]\n\n")) + }() + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 10, result.usage.InputTokens) + require.Equal(t, 15, result.usage.OutputTokens) + require.Equal(t, 20, result.usage.CacheCreationInputTokens) + require.Equal(t, 30, result.usage.CacheReadInputTokens) +} + +func TestHandleStreamingResponse_EmptyStream(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newMinimalGatewayService() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + // 直接关闭,不发送任何事件 + _ = pw.Close() + }() + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) +} + +func TestHandleStreamingResponse_SpecialCharactersInJSON(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newMinimalGatewayService() 
+ + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + defer func() { _ = pw.Close() }() + // 包含特殊字符的 content_block_delta(引号、换行、Unicode) + _, _ = pw.Write([]byte("data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello \\\"world\\\"\\n你好\"}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":5}}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":3}}\n\n")) + _, _ = pw.Write([]byte("data: [DONE]\n\n")) + }() + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 5, result.usage.InputTokens) + require.Equal(t, 3, result.usage.OutputTokens) + + // 验证响应中包含转发的数据 + body := rec.Body.String() + require.Contains(t, body, "content_block_delta", "响应应包含转发的 SSE 事件") +} diff --git a/backend/internal/service/gateway_waiting_queue_test.go b/backend/internal/service/gateway_waiting_queue_test.go new file mode 100644 index 00000000..0c53323e --- /dev/null +++ b/backend/internal/service/gateway_waiting_queue_test.go @@ -0,0 +1,120 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +// TestDecrementWaitCount_NilCache 确保 nil cache 不会 panic +func TestDecrementWaitCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + // 不应 panic + svc.DecrementWaitCount(context.Background(), 1) +} + +// TestDecrementWaitCount_CacheError 确保 cache 错误不会传播 +func TestDecrementWaitCount_CacheError(t *testing.T) { + cache := &stubConcurrencyCacheForTest{} 
+ svc := NewConcurrencyService(cache) + // DecrementWaitCount 使用 background context,错误只记录日志不传播 + svc.DecrementWaitCount(context.Background(), 1) +} + +// TestDecrementAccountWaitCount_NilCache 确保 nil cache 不会 panic +func TestDecrementAccountWaitCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + svc.DecrementAccountWaitCount(context.Background(), 1) +} + +// TestDecrementAccountWaitCount_CacheError 确保 cache 错误不会传播 +func TestDecrementAccountWaitCount_CacheError(t *testing.T) { + cache := &stubConcurrencyCacheForTest{} + svc := NewConcurrencyService(cache) + svc.DecrementAccountWaitCount(context.Background(), 1) +} + +// TestWaitingQueueFlow_IncrementThenDecrement 测试完整的等待队列增减流程 +func TestWaitingQueueFlow_IncrementThenDecrement(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitAllowed: true} + svc := NewConcurrencyService(cache) + + // 进入等待队列 + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.True(t, allowed) + + // 离开等待队列(不应 panic) + svc.DecrementWaitCount(context.Background(), 1) +} + +// TestWaitingQueueFlow_AccountLevel 测试账号级等待队列流程 +func TestWaitingQueueFlow_AccountLevel(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitAllowed: true} + svc := NewConcurrencyService(cache) + + // 进入账号等待队列 + allowed, err := svc.IncrementAccountWaitCount(context.Background(), 42, 10) + require.NoError(t, err) + require.True(t, allowed) + + // 离开账号等待队列 + svc.DecrementAccountWaitCount(context.Background(), 42) +} + +// TestWaitingQueueFull_Returns429Signal 测试等待队列满时返回 false +func TestWaitingQueueFull_Returns429Signal(t *testing.T) { + // waitAllowed=false 模拟队列已满 + cache := &stubConcurrencyCacheForTest{waitAllowed: false} + svc := NewConcurrencyService(cache) + + // 用户级等待队列满 + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.False(t, allowed, "等待队列满时应返回 false(调用方根据此返回 429)") + + // 账号级等待队列满 + allowed, err = 
svc.IncrementAccountWaitCount(context.Background(), 1, 10) + require.NoError(t, err) + require.False(t, allowed, "账号等待队列满时应返回 false") +} + +// TestWaitingQueue_FailOpen_OnCacheError 测试 Redis 故障时 fail-open +func TestWaitingQueue_FailOpen_OnCacheError(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitErr: errors.New("redis connection refused")} + svc := NewConcurrencyService(cache) + + // 用户级:Redis 错误时允许通过 + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err, "Redis 错误不应向调用方传播") + require.True(t, allowed, "Redis 故障时应 fail-open 放行") + + // 账号级:同样 fail-open + allowed, err = svc.IncrementAccountWaitCount(context.Background(), 1, 10) + require.NoError(t, err, "Redis 错误不应向调用方传播") + require.True(t, allowed, "Redis 故障时应 fail-open 放行") +} + +// TestCalculateMaxWait_Scenarios 测试最大等待队列大小计算 +func TestCalculateMaxWait_Scenarios(t *testing.T) { + tests := []struct { + concurrency int + expected int + }{ + {5, 25}, // 5 + 20 + {10, 30}, // 10 + 20 + {1, 21}, // 1 + 20 + {0, 21}, // min(1) + 20 + {-1, 21}, // min(1) + 20 + {-10, 21}, // min(1) + 20 + {100, 120}, // 100 + 20 + } + for _, tt := range tests { + result := CalculateMaxWait(tt.concurrency) + require.Equal(t, tt.expected, result, "CalculateMaxWait(%d)", tt.concurrency) + } +} diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index f3abd1dc..a003f636 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -22,10 +22,12 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" ) 
const geminiStickySessionTTL = time.Hour @@ -51,6 +53,7 @@ type GeminiMessagesCompatService struct { httpUpstream HTTPUpstream antigravityGatewayService *AntigravityGatewayService cfg *config.Config + responseHeaderFilter *responseheaders.CompiledHeaderFilter } func NewGeminiMessagesCompatService( @@ -74,6 +77,7 @@ func NewGeminiMessagesCompatService( httpUpstream: httpUpstream, antigravityGatewayService: antigravityGatewayService, cfg: cfg, + responseHeaderFilter: compileResponseHeaderFilter(cfg), } } @@ -227,6 +231,16 @@ func (s *GeminiMessagesCompatService) isAccountUsableForRequest( account *Account, requestedModel, platform string, useMixedScheduling bool, +) bool { + return s.isAccountUsableForRequestWithPrecheck(ctx, account, requestedModel, platform, useMixedScheduling, nil) +} + +func (s *GeminiMessagesCompatService) isAccountUsableForRequestWithPrecheck( + ctx context.Context, + account *Account, + requestedModel, platform string, + useMixedScheduling bool, + precheckResult map[int64]bool, ) bool { // 检查模型调度能力 // Check model scheduling capability @@ -248,7 +262,7 @@ func (s *GeminiMessagesCompatService) isAccountUsableForRequest( // 速率限制预检 // Rate limit precheck - if !s.passesRateLimitPreCheck(ctx, account, requestedModel) { + if !s.passesRateLimitPreCheckWithCache(ctx, account, requestedModel, precheckResult) { return false } @@ -270,18 +284,20 @@ func (s *GeminiMessagesCompatService) isAccountValidForPlatform(account *Account return false } -// passesRateLimitPreCheck 执行速率限制预检。 -// 返回 true 表示通过预检或无需预检。 -// -// passesRateLimitPreCheck performs rate limit precheck. -// Returns true if passed or precheck not required. 
-func (s *GeminiMessagesCompatService) passesRateLimitPreCheck(ctx context.Context, account *Account, requestedModel string) bool { +func (s *GeminiMessagesCompatService) passesRateLimitPreCheckWithCache(ctx context.Context, account *Account, requestedModel string, precheckResult map[int64]bool) bool { if s.rateLimitService == nil || requestedModel == "" { return true } + + if precheckResult != nil { + if ok, exists := precheckResult[account.ID]; exists { + return ok + } + } + ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel) if err != nil { - log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err) + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini PreCheck] Account %d precheck error: %v", account.ID, err) } return ok } @@ -300,6 +316,7 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount( useMixedScheduling bool, ) *Account { var selected *Account + precheckResult := s.buildPreCheckUsageResultMap(ctx, accounts, requestedModel) for i := range accounts { acc := &accounts[i] @@ -310,7 +327,7 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount( } // 检查账号是否可用于当前请求 - if !s.isAccountUsableForRequest(ctx, acc, requestedModel, platform, useMixedScheduling) { + if !s.isAccountUsableForRequestWithPrecheck(ctx, acc, requestedModel, platform, useMixedScheduling, precheckResult) { continue } @@ -328,6 +345,23 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount( return selected } +func (s *GeminiMessagesCompatService) buildPreCheckUsageResultMap(ctx context.Context, accounts []Account, requestedModel string) map[int64]bool { + if s.rateLimitService == nil || requestedModel == "" || len(accounts) == 0 { + return nil + } + + candidates := make([]*Account, 0, len(accounts)) + for i := range accounts { + candidates = append(candidates, &accounts[i]) + } + + result, err := s.rateLimitService.PreCheckUsageBatch(ctx, candidates, requestedModel) + if err != nil { + 
logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini PreCheckBatch] failed: %v", err) + } + return result +} + // isBetterGeminiAccount 判断 candidate 是否比 current 更优。 // 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先(OAuth > 非 OAuth),其次是最久未使用的。 // @@ -397,7 +431,10 @@ func (s *GeminiMessagesCompatService) listSchedulableAccountsOnce(ctx context.Co if groupID != nil { return s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, queryPlatforms) } - return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms) + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + return s.accountRepo.ListSchedulableByPlatforms(ctx, queryPlatforms) + } + return s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, queryPlatforms) } func (s *GeminiMessagesCompatService) validateUpstreamBaseURL(raw string) (string, error) { @@ -697,7 +734,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex Message: safeErr, }) if attempt < geminiMaxRetries { - log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err) + logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err) sleepGeminiBackoff(attempt) continue } @@ -753,7 +790,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex } retryGeminiReq, txErr := convertClaudeMessagesToGeminiGenerateContent(strippedClaudeBody) if txErr == nil { - log.Printf("Gemini account %d: detected signature-related 400, retrying with downgraded Claude blocks (%s)", account.ID, stageName) + logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: detected signature-related 400, retrying with downgraded Claude blocks (%s)", account.ID, stageName) geminiReq = retryGeminiReq // Consume one retry budget attempt and continue with the updated request payload. 
sleepGeminiBackoff(1) @@ -820,7 +857,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex Detail: upstreamDetail, }) - log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries) + logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries) sleepGeminiBackoff(attempt) continue } @@ -968,7 +1005,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex if err != nil { return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream stream") } - claudeResp, usageObj2 := convertGeminiToClaudeMessage(collected, originalModel) + collectedBytes, _ := json.Marshal(collected) + claudeResp, usageObj2 := convertGeminiToClaudeMessage(collected, originalModel, collectedBytes) c.JSON(http.StatusOK, claudeResp) usage = usageObj2 if usageObj != nil && (usageObj.InputTokens > 0 || usageObj.OutputTokens > 0) { @@ -1195,7 +1233,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. Message: safeErr, }) if attempt < geminiMaxRetries { - log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err) + logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err) sleepGeminiBackoff(attempt) continue } @@ -1264,7 +1302,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. 
Detail: upstreamDetail, }) - log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries) + logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries) sleepGeminiBackoff(attempt) continue } @@ -1424,7 +1462,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. maxBytes = 2048 } upstreamDetail = truncateString(string(respBody), maxBytes) - log.Printf("[Gemini] native upstream error %d: %s", resp.StatusCode, truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)) + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini] native upstream error %d: %s", resp.StatusCode, truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)) } setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ @@ -1601,7 +1639,7 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, acc }) if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf("[Gemini] upstream error %d: %s", upstreamStatus, truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)) + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini] upstream error %d: %s", upstreamStatus, truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)) } if status, errType, errMsg, matched := applyErrorPassthroughRule( @@ -1821,12 +1859,17 @@ func (s *GeminiMessagesCompatService) handleNonStreamingResponse(c *gin.Context, return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response") } - geminiResp, err := unwrapGeminiResponse(body) + unwrappedBody, err := unwrapGeminiResponse(body) if err != nil { return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") } - claudeResp, usage := 
convertGeminiToClaudeMessage(geminiResp, originalModel) + var geminiResp map[string]any + if err := json.Unmarshal(unwrappedBody, &geminiResp); err != nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") + } + + claudeResp, usage := convertGeminiToClaudeMessage(geminiResp, originalModel, unwrappedBody) c.JSON(http.StatusOK, claudeResp) return usage, nil @@ -1899,11 +1942,16 @@ func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, re continue } - geminiResp, err := unwrapGeminiResponse([]byte(payload)) + unwrappedBytes, err := unwrapGeminiResponse([]byte(payload)) if err != nil { continue } + var geminiResp map[string]any + if err := json.Unmarshal(unwrappedBytes, &geminiResp); err != nil { + continue + } + if fr := extractGeminiFinishReason(geminiResp); fr != "" { finishReason = fr } @@ -2030,7 +2078,7 @@ func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, re } } - if u := extractGeminiUsage(geminiResp); u != nil { + if u := extractGeminiUsage(unwrappedBytes); u != nil { usage = *u } @@ -2121,11 +2169,7 @@ func unwrapIfNeeded(isOAuth bool, raw []byte) []byte { if err != nil { return raw } - b, err := json.Marshal(inner) - if err != nil { - return raw - } - return b + return inner } func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsage, error) { @@ -2149,17 +2193,20 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag } default: var parsed map[string]any + var rawBytes []byte if isOAuth { - inner, err := unwrapGeminiResponse([]byte(payload)) - if err == nil && inner != nil { - parsed = inner + innerBytes, err := unwrapGeminiResponse([]byte(payload)) + if err == nil { + rawBytes = innerBytes + _ = json.Unmarshal(innerBytes, &parsed) } } else { - _ = json.Unmarshal([]byte(payload), &parsed) + rawBytes = []byte(payload) + _ = json.Unmarshal(rawBytes, &parsed) } if parsed != nil { last = parsed - 
if u := extractGeminiUsage(parsed); u != nil { + if u := extractGeminiUsage(rawBytes); u != nil { usage = u } if parts := extractGeminiParts(parsed); len(parts) > 0 { @@ -2288,53 +2335,27 @@ func isGeminiInsufficientScope(headers http.Header, body []byte) bool { } func estimateGeminiCountTokens(reqBody []byte) int { - var obj map[string]any - if err := json.Unmarshal(reqBody, &obj); err != nil { - return 0 - } - - var texts []string + total := 0 // systemInstruction.parts[].text - if si, ok := obj["systemInstruction"].(map[string]any); ok { - if parts, ok := si["parts"].([]any); ok { - for _, p := range parts { - if pm, ok := p.(map[string]any); ok { - if t, ok := pm["text"].(string); ok && strings.TrimSpace(t) != "" { - texts = append(texts, t) - } - } - } + gjson.GetBytes(reqBody, "systemInstruction.parts").ForEach(func(_, part gjson.Result) bool { + if t := strings.TrimSpace(part.Get("text").String()); t != "" { + total += estimateTokensForText(t) } - } + return true + }) // contents[].parts[].text - if contents, ok := obj["contents"].([]any); ok { - for _, c := range contents { - cm, ok := c.(map[string]any) - if !ok { - continue + gjson.GetBytes(reqBody, "contents").ForEach(func(_, content gjson.Result) bool { + content.Get("parts").ForEach(func(_, part gjson.Result) bool { + if t := strings.TrimSpace(part.Get("text").String()); t != "" { + total += estimateTokensForText(t) } - parts, ok := cm["parts"].([]any) - if !ok { - continue - } - for _, p := range parts { - pm, ok := p.(map[string]any) - if !ok { - continue - } - if t, ok := pm["text"].(string); ok && strings.TrimSpace(t) != "" { - texts = append(texts, t) - } - } - } - } + return true + }) + return true + }) - total := 0 - for _, t := range texts { - total += estimateTokensForText(t) - } if total < 0 { return 0 } @@ -2372,31 +2393,39 @@ type UpstreamHTTPResult struct { } func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) 
(*ClaudeUsage, error) { - // Log response headers for debugging - log.Printf("[GeminiAPI] ========== Response Headers ==========") - for key, values := range resp.Header { - if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") { - log.Printf("[GeminiAPI] %s: %v", key, values) + if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders { + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========") + for key, values := range resp.Header { + if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") { + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values) + } } + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================") } - log.Printf("[GeminiAPI] ========================================") - respBody, err := io.ReadAll(resp.Body) + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } return nil, err } - var parsed map[string]any if isOAuth { - parsed, err = unwrapGeminiResponse(respBody) - if err == nil && parsed != nil { - respBody, _ = json.Marshal(parsed) + unwrappedBody, uwErr := unwrapGeminiResponse(respBody) + if uwErr == nil { + respBody = unwrappedBody } - } else { - _ = json.Unmarshal(respBody, &parsed) } - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := resp.Header.Get("Content-Type") if contentType == "" { @@ -2404,26 +2433,25 @@ func (s *GeminiMessagesCompatService) 
handleNativeNonStreamingResponse(c *gin.Co } c.Data(resp.StatusCode, contentType, respBody) - if parsed != nil { - if u := extractGeminiUsage(parsed); u != nil { - return u, nil - } + if u := extractGeminiUsage(respBody); u != nil { + return u, nil } return &ClaudeUsage{}, nil } func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) { - // Log response headers for debugging - log.Printf("[GeminiAPI] ========== Streaming Response Headers ==========") - for key, values := range resp.Header { - if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") { - log.Printf("[GeminiAPI] %s: %v", key, values) + if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders { + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========") + for key, values := range resp.Header { + if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") { + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values) + } } + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================") } - log.Printf("[GeminiAPI] ====================================================") - if s.cfg != nil { - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) } c.Status(resp.StatusCode) @@ -2460,23 +2488,19 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte var rawToWrite string rawToWrite = payload - var parsed map[string]any + var rawBytes []byte if isOAuth { - inner, err := unwrapGeminiResponse([]byte(payload)) - if err == nil && inner != nil { - parsed = inner - if b, err := json.Marshal(inner); err == nil { - rawToWrite = string(b) - } + 
innerBytes, err := unwrapGeminiResponse([]byte(payload)) + if err == nil { + rawToWrite = string(innerBytes) + rawBytes = innerBytes } } else { - _ = json.Unmarshal([]byte(payload), &parsed) + rawBytes = []byte(payload) } - if parsed != nil { - if u := extractGeminiUsage(parsed); u != nil { - usage = u - } + if u := extractGeminiUsage(rawBytes); u != nil { + usage = u } if firstTokenMs == nil { @@ -2568,7 +2592,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20)) wwwAuthenticate := resp.Header.Get("Www-Authenticate") - filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders) + filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.responseHeaderFilter) if wwwAuthenticate != "" { filteredHeaders.Set("Www-Authenticate", wwwAuthenticate) } @@ -2579,19 +2603,18 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac }, nil } -func unwrapGeminiResponse(raw []byte) (map[string]any, error) { - var outer map[string]any - if err := json.Unmarshal(raw, &outer); err != nil { - return nil, err +// unwrapGeminiResponse 解包 Gemini OAuth 响应中的 response 字段 +// 使用 gjson 零拷贝提取,避免完整 Unmarshal+Marshal +func unwrapGeminiResponse(raw []byte) ([]byte, error) { + result := gjson.GetBytes(raw, "response") + if result.Exists() && result.Type == gjson.JSON { + return []byte(result.Raw), nil } - if resp, ok := outer["response"].(map[string]any); ok && resp != nil { - return resp, nil - } - return outer, nil + return raw, nil } -func convertGeminiToClaudeMessage(geminiResp map[string]any, originalModel string) (map[string]any, *ClaudeUsage) { - usage := extractGeminiUsage(geminiResp) +func convertGeminiToClaudeMessage(geminiResp map[string]any, originalModel string, rawData []byte) (map[string]any, *ClaudeUsage) { + usage := extractGeminiUsage(rawData) if usage == nil { usage = &ClaudeUsage{} } @@ -2655,15 +2678,15 @@ func 
convertGeminiToClaudeMessage(geminiResp map[string]any, originalModel strin return resp, usage } -func extractGeminiUsage(geminiResp map[string]any) *ClaudeUsage { - usageMeta, ok := geminiResp["usageMetadata"].(map[string]any) - if !ok || usageMeta == nil { +func extractGeminiUsage(data []byte) *ClaudeUsage { + usage := gjson.GetBytes(data, "usageMetadata") + if !usage.Exists() { return nil } - prompt, _ := asInt(usageMeta["promptTokenCount"]) - cand, _ := asInt(usageMeta["candidatesTokenCount"]) - cached, _ := asInt(usageMeta["cachedContentTokenCount"]) - thoughts, _ := asInt(usageMeta["thoughtsTokenCount"]) + prompt := int(usage.Get("promptTokenCount").Int()) + cand := int(usage.Get("candidatesTokenCount").Int()) + cached := int(usage.Get("cachedContentTokenCount").Int()) + thoughts := int(usage.Get("thoughtsTokenCount").Int()) // 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount, // 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去 return &ClaudeUsage{ @@ -2721,16 +2744,16 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont cooldown = s.rateLimitService.GeminiCooldown(ctx, account) } ra = time.Now().Add(cooldown) - log.Printf("[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second)) + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second)) } else { // API Key / AI Studio OAuth: PST 午夜 if ts := nextGeminiDailyResetUnix(); ts != nil { ra = time.Unix(*ts, 0) - log.Printf("[Gemini 429] Account %d (API Key/AI Studio, type=%s) rate limited, reset at PST midnight (%v)", account.ID, account.Type, ra) + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d (API Key/AI Studio, type=%s) rate limited, reset at PST midnight (%v)", account.ID, account.Type, 
ra) } else { // 兜底:5 分钟 ra = time.Now().Add(5 * time.Minute) - log.Printf("[Gemini 429] Account %d rate limited, fallback to 5min", account.ID) + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d rate limited, fallback to 5min", account.ID) } } _ = s.accountRepo.SetRateLimited(ctx, account.ID, ra) @@ -2740,45 +2763,41 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont // 使用解析到的重置时间 resetTime := time.Unix(*resetAt, 0) _ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime) - log.Printf("[Gemini 429] Account %d rate limited until %v (oauth_type=%s, tier=%s)", + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d rate limited until %v (oauth_type=%s, tier=%s)", account.ID, resetTime, oauthType, tierID) } // ParseGeminiRateLimitResetTime 解析 Gemini 格式的 429 响应,返回重置时间的 Unix 时间戳 func ParseGeminiRateLimitResetTime(body []byte) *int64 { - // Try to parse metadata.quotaResetDelay like "12.345s" - var parsed map[string]any - if err := json.Unmarshal(body, &parsed); err == nil { - if errObj, ok := parsed["error"].(map[string]any); ok { - if msg, ok := errObj["message"].(string); ok { - if looksLikeGeminiDailyQuota(msg) { - if ts := nextGeminiDailyResetUnix(); ts != nil { - return ts - } - } - } - if details, ok := errObj["details"].([]any); ok { - for _, d := range details { - dm, ok := d.(map[string]any) - if !ok { - continue - } - if meta, ok := dm["metadata"].(map[string]any); ok { - if v, ok := meta["quotaResetDelay"].(string); ok { - if dur, err := time.ParseDuration(v); err == nil { - // Use ceil to avoid undercounting fractional seconds (e.g. 10.1s should not become 10s), - // which can affect scheduling decisions around thresholds (like 10s). 
- ts := time.Now().Unix() + int64(math.Ceil(dur.Seconds())) - return &ts - } - } - } - } - } + // 第一阶段:gjson 结构化提取 + errMsg := gjson.GetBytes(body, "error.message").String() + if looksLikeGeminiDailyQuota(errMsg) { + if ts := nextGeminiDailyResetUnix(); ts != nil { + return ts } } - // Match "Please retry in Xs" + // 遍历 error.details 查找 quotaResetDelay + var found *int64 + gjson.GetBytes(body, "error.details").ForEach(func(_, detail gjson.Result) bool { + v := detail.Get("metadata.quotaResetDelay").String() + if v == "" { + return true + } + if dur, err := time.ParseDuration(v); err == nil { + // Use ceil to avoid undercounting fractional seconds (e.g. 10.1s should not become 10s), + // which can affect scheduling decisions around thresholds (like 10s). + ts := time.Now().Unix() + int64(math.Ceil(dur.Seconds())) + found = &ts + return false + } + return true + }) + if found != nil { + return found + } + + // 第二阶段:regex 回退匹配 "Please retry in Xs" matches := retryInRegex.FindStringSubmatch(string(body)) if len(matches) == 2 { if dur, err := time.ParseDuration(matches[1] + "s"); err == nil { diff --git a/backend/internal/service/gemini_messages_compat_service_test.go b/backend/internal/service/gemini_messages_compat_service_test.go index 5bc26973..7560f480 100644 --- a/backend/internal/service/gemini_messages_compat_service_test.go +++ b/backend/internal/service/gemini_messages_compat_service_test.go @@ -2,9 +2,16 @@ package service import ( "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" "strings" "testing" + "time" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -131,6 +138,38 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) { } } +func TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLogs(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + svc := &GeminiMessagesCompatService{ 
+ cfg: &config.Config{ + Gateway: config.GatewayConfig{ + GeminiDebugResponseHeaders: false, + }, + }, + } + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "X-RateLimit-Limit": []string{"60"}, + }, + Body: io.NopCloser(strings.NewReader(`{"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":2}}`)), + } + + usage, err := svc.handleNativeNonStreamingResponse(c, resp, false) + require.NoError(t, err) + require.NotNil(t, usage) + require.False(t, logSink.ContainsMessage("[GeminiAPI]"), "debug 关闭时不应输出 Gemini 响应头日志") +} + func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) { claudeReq := map[string]any{ "model": "claude-haiku-4-5-20251001", @@ -206,69 +245,323 @@ func TestEnsureGeminiFunctionCallThoughtSignatures_InsertsWhenMissing(t *testing } } -func TestExtractGeminiUsage_ThoughtsTokenCount(t *testing.T) { +// TestUnwrapGeminiResponse 测试 unwrapGeminiResponse 的各种输入场景 +// 关键区别:只有 response 为 JSON 对象/数组时才解包 +func TestUnwrapGeminiResponse(t *testing.T) { + // 构造 >50KB 的大型 JSON 对象 + largePadding := strings.Repeat("x", 50*1024) + largeInput := []byte(fmt.Sprintf(`{"response":{"id":"big","pad":"%s"}}`, largePadding)) + largeExpected := fmt.Sprintf(`{"id":"big","pad":"%s"}`, largePadding) + tests := []struct { - name string - resp map[string]any - wantInput int - wantOutput int - wantCacheRead int - wantNil bool + name string + input []byte + expected string + wantErr bool }{ { - name: "with thoughtsTokenCount", - resp: map[string]any{ - "usageMetadata": map[string]any{ - "promptTokenCount": float64(100), - "candidatesTokenCount": float64(20), - "thoughtsTokenCount": float64(50), - }, - }, - wantInput: 100, - wantOutput: 70, + name: "正常 response 包装(JSON 对象)", + input: []byte(`{"response":{"key":"val"}}`), 
+ expected: `{"key":"val"}`, }, { - name: "with thoughtsTokenCount and cache", - resp: map[string]any{ - "usageMetadata": map[string]any{ - "promptTokenCount": float64(100), - "candidatesTokenCount": float64(20), - "cachedContentTokenCount": float64(30), - "thoughtsTokenCount": float64(50), - }, - }, - wantInput: 70, - wantOutput: 70, - wantCacheRead: 30, + name: "无包装直接返回", + input: []byte(`{"key":"val"}`), + expected: `{"key":"val"}`, }, { - name: "without thoughtsTokenCount (old model)", - resp: map[string]any{ - "usageMetadata": map[string]any{ - "promptTokenCount": float64(100), - "candidatesTokenCount": float64(20), - }, - }, - wantInput: 100, - wantOutput: 20, + name: "空 JSON", + input: []byte(`{}`), + expected: `{}`, }, { - name: "no usageMetadata", - resp: map[string]any{}, - wantNil: true, + name: "null response 返回原始 body", + input: []byte(`{"response":null}`), + expected: `{"response":null}`, + }, + { + name: "非法 JSON 返回原始 body", + input: []byte(`not json`), + expected: `not json`, + }, + { + name: "response 为基础类型 string 返回原始 body", + input: []byte(`{"response":"hello"}`), + expected: `{"response":"hello"}`, + }, + { + name: "嵌套 response 只解一层", + input: []byte(`{"response":{"response":{"inner":true}}}`), + expected: `{"response":{"inner":true}}`, + }, + { + name: "大型 JSON >50KB", + input: largeInput, + expected: largeExpected, }, } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - usage := extractGeminiUsage(tt.resp) - if tt.wantNil { - require.Nil(t, usage) + got, err := unwrapGeminiResponse(tt.input) + if tt.wantErr { + require.Error(t, err) return } - require.NotNil(t, usage) - require.Equal(t, tt.wantInput, usage.InputTokens) - require.Equal(t, tt.wantOutput, usage.OutputTokens) - require.Equal(t, tt.wantCacheRead, usage.CacheReadInputTokens) + require.NoError(t, err) + require.Equal(t, tt.expected, strings.TrimSpace(string(got))) + }) + } +} + +// --------------------------------------------------------------------------- +// Task 
8.1 — extractGeminiUsage 测试 +// --------------------------------------------------------------------------- + +func TestExtractGeminiUsage(t *testing.T) { + tests := []struct { + name string + input string + wantNil bool + wantUsage *ClaudeUsage + }{ + { + name: "完整 usageMetadata", + input: `{"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":50,"cachedContentTokenCount":20}}`, + wantNil: false, + wantUsage: &ClaudeUsage{ + InputTokens: 80, + OutputTokens: 50, + CacheReadInputTokens: 20, + }, + }, + { + name: "包含 thoughtsTokenCount", + input: `{"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":20,"thoughtsTokenCount":50}}`, + wantNil: false, + wantUsage: &ClaudeUsage{ + InputTokens: 100, + OutputTokens: 70, + CacheReadInputTokens: 0, + }, + }, + { + name: "包含 thoughtsTokenCount 与缓存", + input: `{"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":20,"cachedContentTokenCount":30,"thoughtsTokenCount":50}}`, + wantNil: false, + wantUsage: &ClaudeUsage{ + InputTokens: 70, + OutputTokens: 70, + CacheReadInputTokens: 30, + }, + }, + { + name: "缺失 cachedContentTokenCount", + input: `{"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":50}}`, + wantNil: false, + wantUsage: &ClaudeUsage{ + InputTokens: 100, + OutputTokens: 50, + CacheReadInputTokens: 0, + }, + }, + { + name: "无 usageMetadata", + input: `{"candidates":[]}`, + wantNil: true, + }, + { + // gjson 对 null 返回 Exists()=true,因此函数不会返回 nil, + // 而是返回全零的 ClaudeUsage。 + name: "null usageMetadata — gjson Exists 为 true", + input: `{"usageMetadata":null}`, + wantNil: false, + wantUsage: &ClaudeUsage{ + InputTokens: 0, + OutputTokens: 0, + CacheReadInputTokens: 0, + }, + }, + { + name: "零值字段", + input: `{"usageMetadata":{"promptTokenCount":0,"candidatesTokenCount":0,"cachedContentTokenCount":0}}`, + wantNil: false, + wantUsage: &ClaudeUsage{ + InputTokens: 0, + OutputTokens: 0, + CacheReadInputTokens: 0, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t 
*testing.T) { + got := extractGeminiUsage([]byte(tt.input)) + if tt.wantNil { + if got != nil { + t.Fatalf("期望返回 nil,实际返回 %+v", got) + } + return + } + if got == nil { + t.Fatalf("期望返回非 nil,实际返回 nil") + } + if got.InputTokens != tt.wantUsage.InputTokens { + t.Errorf("InputTokens: 期望 %d,实际 %d", tt.wantUsage.InputTokens, got.InputTokens) + } + if got.OutputTokens != tt.wantUsage.OutputTokens { + t.Errorf("OutputTokens: 期望 %d,实际 %d", tt.wantUsage.OutputTokens, got.OutputTokens) + } + if got.CacheReadInputTokens != tt.wantUsage.CacheReadInputTokens { + t.Errorf("CacheReadInputTokens: 期望 %d,实际 %d", tt.wantUsage.CacheReadInputTokens, got.CacheReadInputTokens) + } + }) + } +} + +// --------------------------------------------------------------------------- +// Task 8.2 — estimateGeminiCountTokens 测试 +// --------------------------------------------------------------------------- + +func TestEstimateGeminiCountTokens(t *testing.T) { + tests := []struct { + name string + input string + wantGt0 bool // 期望结果 > 0 + wantExact *int // 如果非 nil,期望精确匹配 + }{ + { + name: "含 systemInstruction 和 contents", + input: `{ + "systemInstruction":{"parts":[{"text":"You are a helpful assistant."}]}, + "contents":[{"parts":[{"text":"Hello, how are you?"}]}] + }`, + wantGt0: true, + }, + { + name: "仅 contents,无 systemInstruction", + input: `{ + "contents":[{"parts":[{"text":"Hello, how are you?"}]}] + }`, + wantGt0: true, + }, + { + name: "空 parts", + input: `{"contents":[{"parts":[]}]}`, + wantGt0: false, + wantExact: intPtr(0), + }, + { + name: "非文本 parts(inlineData)", + input: `{"contents":[{"parts":[{"inlineData":{"mimeType":"image/png"}}]}]}`, + wantGt0: false, + wantExact: intPtr(0), + }, + { + name: "空白文本", + input: `{"contents":[{"parts":[{"text":" "}]}]}`, + wantGt0: false, + wantExact: intPtr(0), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := estimateGeminiCountTokens([]byte(tt.input)) + if tt.wantExact != nil { + if got != *tt.wantExact { + 
t.Errorf("期望精确值 %d,实际 %d", *tt.wantExact, got) + } + return + } + if tt.wantGt0 && got <= 0 { + t.Errorf("期望返回 > 0,实际 %d", got) + } + if !tt.wantGt0 && got != 0 { + t.Errorf("期望返回 0,实际 %d", got) + } + }) + } +} + +// --------------------------------------------------------------------------- +// Task 8.3 — ParseGeminiRateLimitResetTime 测试 +// --------------------------------------------------------------------------- + +func TestParseGeminiRateLimitResetTime(t *testing.T) { + tests := []struct { + name string + input string + wantNil bool + approxDelta int64 // 预期的 (返回值 - now) 大约是多少秒 + }{ + { + name: "正常 quotaResetDelay", + input: `{"error":{"details":[{"metadata":{"quotaResetDelay":"12.345s"}}]}}`, + wantNil: false, + approxDelta: 13, // 向上取整 12.345 -> 13 + }, + { + name: "daily quota", + input: `{"error":{"message":"quota per day exceeded"}}`, + wantNil: false, + approxDelta: -1, // 不检查精确 delta,仅检查非 nil + }, + { + name: "无 details 且无 regex 匹配", + input: `{"error":{"message":"rate limit"}}`, + wantNil: true, + }, + { + name: "regex 回退匹配", + input: `Please retry in 30s`, + wantNil: false, + approxDelta: 30, + }, + { + name: "完全无匹配", + input: `{"error":{"code":429}}`, + wantNil: true, + }, + { + name: "非法 JSON 但 regex 回退仍工作", + input: `not json but Please retry in 10s`, + wantNil: false, + approxDelta: 10, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + now := time.Now().Unix() + got := ParseGeminiRateLimitResetTime([]byte(tt.input)) + + if tt.wantNil { + if got != nil { + t.Fatalf("期望返回 nil,实际返回 %d", *got) + } + return + } + + if got == nil { + t.Fatalf("期望返回非 nil,实际返回 nil") + } + + // approxDelta == -1 表示只检查非 nil,不检查具体值(如 daily quota 场景) + if tt.approxDelta == -1 { + // 仅验证返回的时间戳在合理范围内(未来的某个时间) + if *got < now { + t.Errorf("期望返回的时间戳 >= now(%d),实际 %d", now, *got) + } + return + } + + // 使用 +/-2 秒容差进行范围检查 + delta := *got - now + if delta < tt.approxDelta-2 || delta > tt.approxDelta+2 { + t.Errorf("期望 delta 约为 %d 秒(+/-2),实际 delta = %d 
秒(返回值=%d, now=%d)", + tt.approxDelta, delta, *got, now) + } }) } } diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 6b1fcecc..9476e984 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -66,6 +66,11 @@ func (m *mockAccountRepoForGemini) Create(ctx context.Context, account *Account) func (m *mockAccountRepoForGemini) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) { return nil, nil } + +func (m *mockAccountRepoForGemini) FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) { + return nil, nil +} + func (m *mockAccountRepoForGemini) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) { return nil, nil } @@ -133,6 +138,12 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont } return m.ListSchedulableByPlatforms(ctx, platforms) } +func (m *mockAccountRepoForGemini) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + return m.ListSchedulableByPlatform(ctx, platform) +} +func (m *mockAccountRepoForGemini) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { + return m.ListSchedulableByPlatforms(ctx, platforms) +} func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return nil } diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index fd2932e6..08a74a37 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "log" "net/http" "regexp" "strconv" @@ -16,6 +15,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" 
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) const ( @@ -54,6 +54,7 @@ type GeminiOAuthService struct { proxyRepo ProxyRepository oauthClient GeminiOAuthClient codeAssist GeminiCliCodeAssistClient + driveClient geminicli.DriveClient cfg *config.Config } @@ -66,6 +67,7 @@ func NewGeminiOAuthService( proxyRepo ProxyRepository, oauthClient GeminiOAuthClient, codeAssist GeminiCliCodeAssistClient, + driveClient geminicli.DriveClient, cfg *config.Config, ) *GeminiOAuthService { return &GeminiOAuthService{ @@ -73,6 +75,7 @@ func NewGeminiOAuthService( proxyRepo: proxyRepo, oauthClient: oauthClient, codeAssist: codeAssist, + driveClient: driveClient, cfg: cfg, } } @@ -81,8 +84,7 @@ func (s *GeminiOAuthService) GetOAuthConfig() *GeminiOAuthCapabilities { // AI Studio OAuth is only enabled when the operator configures a custom OAuth client. clientID := strings.TrimSpace(s.cfg.Gemini.OAuth.ClientID) clientSecret := strings.TrimSpace(s.cfg.Gemini.OAuth.ClientSecret) - enabled := clientID != "" && clientSecret != "" && - (clientID != geminicli.GeminiCLIOAuthClientID || clientSecret != geminicli.GeminiCLIOAuthClientSecret) + enabled := clientID != "" && clientSecret != "" && clientID != geminicli.GeminiCLIOAuthClientID return &GeminiOAuthCapabilities{ AIStudioOAuthEnabled: enabled, @@ -151,8 +153,7 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 return nil, err } - isBuiltinClient := effectiveCfg.ClientID == geminicli.GeminiCLIOAuthClientID && - effectiveCfg.ClientSecret == geminicli.GeminiCLIOAuthClientSecret + isBuiltinClient := effectiveCfg.ClientID == geminicli.GeminiCLIOAuthClientID // AI Studio OAuth requires a user-provided OAuth client (built-in Gemini CLI client is scope-restricted). 
if oauthType == "ai_studio" && isBuiltinClient { @@ -330,27 +331,27 @@ func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string // inferGoogleOneTier infers Google One tier from Drive storage limit func inferGoogleOneTier(storageBytes int64) string { - log.Printf("[GeminiOAuth] inferGoogleOneTier - input: %d bytes (%.2f TB)", storageBytes, float64(storageBytes)/float64(TB)) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - input: %d bytes (%.2f TB)", storageBytes, float64(storageBytes)/float64(TB)) if storageBytes <= 0 { - log.Printf("[GeminiOAuth] inferGoogleOneTier - storageBytes <= 0, returning UNKNOWN") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - storageBytes <= 0, returning UNKNOWN") return GeminiTierGoogleOneUnknown } if storageBytes > StorageTierUnlimited { - log.Printf("[GeminiOAuth] inferGoogleOneTier - > %d bytes (100TB), returning UNLIMITED", StorageTierUnlimited) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - > %d bytes (100TB), returning UNLIMITED", StorageTierUnlimited) return GeminiTierGoogleAIUltra } if storageBytes >= StorageTierAIPremium { - log.Printf("[GeminiOAuth] inferGoogleOneTier - >= %d bytes (2TB), returning google_ai_pro", StorageTierAIPremium) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - >= %d bytes (2TB), returning google_ai_pro", StorageTierAIPremium) return GeminiTierGoogleAIPro } if storageBytes >= StorageTierFree { - log.Printf("[GeminiOAuth] inferGoogleOneTier - >= %d bytes (15GB), returning FREE", StorageTierFree) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - >= %d bytes (15GB), returning FREE", StorageTierFree) return GeminiTierGoogleOneFree } - log.Printf("[GeminiOAuth] inferGoogleOneTier - < %d bytes (15GB), returning UNKNOWN", StorageTierFree) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - < %d 
bytes (15GB), returning UNKNOWN", StorageTierFree) return GeminiTierGoogleOneUnknown } @@ -360,30 +361,29 @@ func inferGoogleOneTier(storageBytes int64) string { // 2. Personal accounts will get 403/404 from cloudaicompanion.googleapis.com // 3. Google consumer (Google One) and enterprise (GCP) systems are physically isolated func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) { - log.Printf("[GeminiOAuth] Starting FetchGoogleOneTier (Google One personal account)") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Starting FetchGoogleOneTier (Google One personal account)") // Use Drive API to infer tier from storage quota (requires drive.readonly scope) - log.Printf("[GeminiOAuth] Calling Drive API for storage quota...") - driveClient := geminicli.NewDriveClient() + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Calling Drive API for storage quota...") - storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL) + storageInfo, err := s.driveClient.GetStorageQuota(ctx, accessToken, proxyURL) if err != nil { // Check if it's a 403 (scope not granted) if strings.Contains(err.Error(), "status 403") { - log.Printf("[GeminiOAuth] Drive API scope not available (403): %v", err) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Drive API scope not available (403): %v", err) return GeminiTierGoogleOneUnknown, nil, err } // Other errors - log.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v", err) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Failed to fetch Drive storage: %v", err) return GeminiTierGoogleOneUnknown, nil, err } - log.Printf("[GeminiOAuth] Drive API response - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)", + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Drive API response - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)", storageInfo.Limit, 
float64(storageInfo.Limit)/float64(TB), storageInfo.Usage, float64(storageInfo.Usage)/float64(GB)) tierID := inferGoogleOneTier(storageInfo.Limit) - log.Printf("[GeminiOAuth] Inferred tier from storage: %s", tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Inferred tier from storage: %s", tierID) return tierID, storageInfo, nil } @@ -443,16 +443,16 @@ func (s *GeminiOAuthService) RefreshAccountGoogleOneTier( } func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) { - log.Printf("[GeminiOAuth] ========== ExchangeCode START ==========") - log.Printf("[GeminiOAuth] SessionID: %s", input.SessionID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== ExchangeCode START ==========") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] SessionID: %s", input.SessionID) session, ok := s.sessionStore.Get(input.SessionID) if !ok { - log.Printf("[GeminiOAuth] ERROR: Session not found or expired") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Session not found or expired") return nil, fmt.Errorf("session not found or expired") } if strings.TrimSpace(input.State) == "" || input.State != session.State { - log.Printf("[GeminiOAuth] ERROR: Invalid state") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Invalid state") return nil, fmt.Errorf("invalid state") } @@ -463,7 +463,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch proxyURL = proxy.URL() } } - log.Printf("[GeminiOAuth] ProxyURL: %s", proxyURL) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ProxyURL: %s", proxyURL) redirectURI := session.RedirectURI @@ -472,8 +472,8 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch if oauthType == "" { oauthType = "code_assist" } - log.Printf("[GeminiOAuth] OAuth Type: %s", oauthType) - log.Printf("[GeminiOAuth] Project ID from session: %s", 
session.ProjectID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] OAuth Type: %s", oauthType) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Project ID from session: %s", session.ProjectID) // If the session was created for AI Studio OAuth, ensure a custom OAuth client is configured. if oauthType == "ai_studio" { @@ -485,26 +485,25 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch if err != nil { return nil, err } - isBuiltinClient := effectiveCfg.ClientID == geminicli.GeminiCLIOAuthClientID && - effectiveCfg.ClientSecret == geminicli.GeminiCLIOAuthClientSecret + isBuiltinClient := effectiveCfg.ClientID == geminicli.GeminiCLIOAuthClientID if isBuiltinClient { return nil, fmt.Errorf("AI Studio OAuth requires a custom OAuth Client. Please use an AI Studio API Key account, or configure GEMINI_OAUTH_CLIENT_ID / GEMINI_OAUTH_CLIENT_SECRET and re-authorize") } } - // code_assist always uses the built-in client and its fixed redirect URI. - if oauthType == "code_assist" { + // code_assist/google_one always uses the built-in client and its fixed redirect URI. 
+ if oauthType == "code_assist" || oauthType == "google_one" { redirectURI = geminicli.GeminiCLIRedirectURI } tokenResp, err := s.oauthClient.ExchangeCode(ctx, oauthType, input.Code, session.CodeVerifier, redirectURI, proxyURL) if err != nil { - log.Printf("[GeminiOAuth] ERROR: Failed to exchange code: %v", err) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Failed to exchange code: %v", err) return nil, fmt.Errorf("failed to exchange code: %w", err) } - log.Printf("[GeminiOAuth] Token exchange successful") - log.Printf("[GeminiOAuth] Token scope: %s", tokenResp.Scope) - log.Printf("[GeminiOAuth] Token expires_in: %d seconds", tokenResp.ExpiresIn) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Token exchange successful") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Token scope: %s", tokenResp.Scope) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Token expires_in: %d seconds", tokenResp.ExpiresIn) sessionProjectID := strings.TrimSpace(session.ProjectID) s.sessionStore.Delete(input.SessionID) @@ -526,40 +525,40 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch fallbackTierID = canonicalGeminiTierIDForOAuthType(oauthType, session.TierID) } - log.Printf("[GeminiOAuth] ========== Account Type Detection START ==========") - log.Printf("[GeminiOAuth] OAuth Type: %s", oauthType) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== Account Type Detection START ==========") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] OAuth Type: %s", oauthType) // 对于 code_assist 模式,project_id 是必需的,需要调用 Code Assist API // 对于 google_one 模式,使用个人 Google 账号,不需要 project_id,配额由 Google 网关自动识别 // 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API) switch oauthType { case "code_assist": - log.Printf("[GeminiOAuth] Processing code_assist OAuth type") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Processing code_assist OAuth type") if projectID == "" { - 
log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...") var err error projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) if err != nil { // 记录警告但不阻断流程,允许后续补充 project_id fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err) - log.Printf("[GeminiOAuth] WARNING: Failed to fetch project_id: %v", err) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] WARNING: Failed to fetch project_id: %v", err) } else { - log.Printf("[GeminiOAuth] Successfully fetched project_id: %s, tier_id: %s", projectID, tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Successfully fetched project_id: %s, tier_id: %s", projectID, tierID) } } else { - log.Printf("[GeminiOAuth] User provided project_id: %s, fetching tier_id...", projectID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] User provided project_id: %s, fetching tier_id...", projectID) // 用户手动填了 project_id,仍需调用 LoadCodeAssist 获取 tierID _, fetchedTierID, err := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) if err != nil { fmt.Printf("[GeminiOAuth] Warning: Failed to fetch tierID: %v\n", err) - log.Printf("[GeminiOAuth] WARNING: Failed to fetch tier_id: %v", err) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] WARNING: Failed to fetch tier_id: %v", err) } else { tierID = fetchedTierID - log.Printf("[GeminiOAuth] Successfully fetched tier_id: %s", tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Successfully fetched tier_id: %s", tierID) } } if strings.TrimSpace(projectID) == "" { - log.Printf("[GeminiOAuth] ERROR: Missing project_id for Code Assist OAuth") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Missing project_id for Code Assist OAuth") return nil, fmt.Errorf("missing project_id 
for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project") } // Prefer auto-detected tier; fall back to user-selected tier. @@ -567,31 +566,31 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch if tierID == "" { if fallbackTierID != "" { tierID = fallbackTierID - log.Printf("[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID) } else { tierID = GeminiTierGCPStandard - log.Printf("[GeminiOAuth] Using default tier_id: %s", tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Using default tier_id: %s", tierID) } } - log.Printf("[GeminiOAuth] Final code_assist result - project_id: %s, tier_id: %s", projectID, tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Final code_assist result - project_id: %s, tier_id: %s", projectID, tierID) case "google_one": - log.Printf("[GeminiOAuth] Processing google_one OAuth type") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Processing google_one OAuth type") // Google One accounts use cloudaicompanion API, which requires a project_id. // For personal accounts, Google auto-assigns a project_id via the LoadCodeAssist API. 
if projectID == "" { - log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...") var err error projectID, _, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) if err != nil { - log.Printf("[GeminiOAuth] ERROR: Failed to fetch project_id: %v", err) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Failed to fetch project_id: %v", err) return nil, fmt.Errorf("google One accounts require a project_id, failed to auto-detect: %w", err) } - log.Printf("[GeminiOAuth] Successfully fetched project_id: %s", projectID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Successfully fetched project_id: %s", projectID) } - log.Printf("[GeminiOAuth] Attempting to fetch Google One tier from Drive API...") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Attempting to fetch Google One tier from Drive API...") // Attempt to fetch Drive storage tier var storageInfo *geminicli.DriveStorageInfo var err error @@ -599,12 +598,12 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch if err != nil { // Log warning but don't block - use fallback fmt.Printf("[GeminiOAuth] Warning: Failed to fetch Drive tier: %v\n", err) - log.Printf("[GeminiOAuth] WARNING: Failed to fetch Drive tier: %v", err) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] WARNING: Failed to fetch Drive tier: %v", err) tierID = "" } else { - log.Printf("[GeminiOAuth] Successfully fetched Drive tier: %s", tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Successfully fetched Drive tier: %s", tierID) if storageInfo != nil { - log.Printf("[GeminiOAuth] Drive storage - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)", + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Drive storage - Limit: %d bytes (%.2f TB), Usage: %d bytes 
(%.2f GB)", storageInfo.Limit, float64(storageInfo.Limit)/float64(TB), storageInfo.Usage, float64(storageInfo.Usage)/float64(GB)) } @@ -613,10 +612,10 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch if tierID == "" || tierID == GeminiTierGoogleOneUnknown { if fallbackTierID != "" { tierID = fallbackTierID - log.Printf("[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID) } else { tierID = GeminiTierGoogleOneFree - log.Printf("[GeminiOAuth] Using default tier_id: %s", tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Using default tier_id: %s", tierID) } } fmt.Printf("[GeminiOAuth] Google One tierID after normalization: %s\n", tierID) @@ -639,7 +638,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch "drive_tier_updated_at": time.Now().Format(time.RFC3339), }, } - log.Printf("[GeminiOAuth] ========== ExchangeCode END (google_one with storage info) ==========") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== ExchangeCode END (google_one with storage info) ==========") return tokenInfo, nil } @@ -652,10 +651,10 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch } default: - log.Printf("[GeminiOAuth] Processing %s OAuth type (no tier detection)", oauthType) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Processing %s OAuth type (no tier detection)", oauthType) } - log.Printf("[GeminiOAuth] ========== Account Type Detection END ==========") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== Account Type Detection END ==========") result := &GeminiTokenInfo{ AccessToken: tokenResp.AccessToken, @@ -668,8 +667,8 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch TierID: tierID, OAuthType: oauthType, } - log.Printf("[GeminiOAuth] 
Final result - OAuth Type: %s, Project ID: %s, Tier ID: %s", result.OAuthType, result.ProjectID, result.TierID) - log.Printf("[GeminiOAuth] ========== ExchangeCode END ==========") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Final result - OAuth Type: %s, Project ID: %s, Tier ID: %s", result.OAuthType, result.ProjectID, result.TierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== ExchangeCode END ==========") return result, nil } @@ -952,23 +951,23 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr registeredTierID := strings.TrimSpace(loadResp.GetTier()) if registeredTierID != "" { // 已注册但未返回 cloudaicompanionProject,这在 Google One 用户中较常见:需要用户自行提供 project_id。 - log.Printf("[GeminiOAuth] User has tier (%s) but no cloudaicompanionProject, trying Cloud Resource Manager...", registeredTierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] User has tier (%s) but no cloudaicompanionProject, trying Cloud Resource Manager...", registeredTierID) // Try to get project from Cloud Resource Manager fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) if fbErr == nil && strings.TrimSpace(fallback) != "" { - log.Printf("[GeminiOAuth] Found project from Cloud Resource Manager: %s", fallback) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Found project from Cloud Resource Manager: %s", fallback) return strings.TrimSpace(fallback), tierID, nil } // No project found - user must provide project_id manually - log.Printf("[GeminiOAuth] No project found from Cloud Resource Manager, user must provide project_id manually") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] No project found from Cloud Resource Manager, user must provide project_id manually") return "", tierID, fmt.Errorf("user is registered (tier: %s) but no project_id available. 
Please provide Project ID manually in the authorization form, or create a project at https://console.cloud.google.com", registeredTierID) } } // 未检测到 currentTier/paidTier,视为新用户,继续调用 onboardUser - log.Printf("[GeminiOAuth] No currentTier/paidTier found, proceeding with onboardUser (tierID: %s)", tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] No currentTier/paidTier found, proceeding with onboardUser (tierID: %s)", tierID) req := &geminicli.OnboardUserRequest{ TierID: tierID, @@ -1046,7 +1045,7 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR ValidateResolvedIP: true, }) if err != nil { - client = &http.Client{Timeout: 30 * time.Second} + return "", fmt.Errorf("create http client failed: %w", err) } resp, err := client.Do(req) diff --git a/backend/internal/service/gemini_oauth_service_test.go b/backend/internal/service/gemini_oauth_service_test.go index 5591eb39..397b581d 100644 --- a/backend/internal/service/gemini_oauth_service_test.go +++ b/backend/internal/service/gemini_oauth_service_test.go @@ -1,17 +1,29 @@ +//go:build unit + package service import ( "context" + "fmt" "net/url" "strings" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) +// ===================== +// 保留原有测试 +// ===================== + func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) { - t.Parallel() + // NOTE: This test sets process env; it must not run in parallel. + // The built-in Gemini CLI client secret is not embedded in this repository. + // Tests set a dummy secret via env to simulate operator-provided configuration. 
+ t.Setenv(geminicli.GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") type testCase struct { name string @@ -89,7 +101,7 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - svc := NewGeminiOAuthService(nil, nil, nil, tt.cfg) + svc := NewGeminiOAuthService(nil, nil, nil, nil, tt.cfg) got, err := svc.GenerateAuthURL(context.Background(), nil, "https://example.com/auth/callback", tt.projectID, tt.oauthType, "") if tt.wantErrSubstr != "" { if err == nil { @@ -128,3 +140,1336 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) { }) } } + +// ===================== +// 新增测试:validateTierID +// ===================== + +func TestValidateTierID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + tierID string + wantErr bool + }{ + {name: "空字符串合法", tierID: "", wantErr: false}, + {name: "正常 tier_id", tierID: "google_one_free", wantErr: false}, + {name: "包含斜杠", tierID: "tier/sub", wantErr: false}, + {name: "包含连字符", tierID: "gcp-standard", wantErr: false}, + {name: "纯数字", tierID: "12345", wantErr: false}, + {name: "超长字符串(65个字符)", tierID: strings.Repeat("a", 65), wantErr: true}, + {name: "刚好64个字符", tierID: strings.Repeat("b", 64), wantErr: false}, + {name: "非法字符_空格", tierID: "tier id", wantErr: true}, + {name: "非法字符_中文", tierID: "tier_中文", wantErr: true}, + {name: "非法字符_特殊符号", tierID: "tier@id", wantErr: true}, + {name: "非法字符_感叹号", tierID: "tier!id", wantErr: true}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := validateTierID(tt.tierID) + if tt.wantErr && err == nil { + t.Fatalf("期望返回错误,但返回 nil") + } + if !tt.wantErr && err != nil { + t.Fatalf("不期望返回错误,但返回: %v", err) + } + }) + } +} + +// ===================== +// 新增测试:canonicalGeminiTierID +// ===================== + +func TestCanonicalGeminiTierID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + raw string + want 
string + }{ + // 空值 + {name: "空字符串", raw: "", want: ""}, + {name: "纯空白", raw: " ", want: ""}, + + // 已规范化的值(直接返回) + {name: "google_one_free", raw: "google_one_free", want: GeminiTierGoogleOneFree}, + {name: "google_ai_pro", raw: "google_ai_pro", want: GeminiTierGoogleAIPro}, + {name: "google_ai_ultra", raw: "google_ai_ultra", want: GeminiTierGoogleAIUltra}, + {name: "gcp_standard", raw: "gcp_standard", want: GeminiTierGCPStandard}, + {name: "gcp_enterprise", raw: "gcp_enterprise", want: GeminiTierGCPEnterprise}, + {name: "aistudio_free", raw: "aistudio_free", want: GeminiTierAIStudioFree}, + {name: "aistudio_paid", raw: "aistudio_paid", want: GeminiTierAIStudioPaid}, + {name: "google_one_unknown", raw: "google_one_unknown", want: GeminiTierGoogleOneUnknown}, + + // 大小写不敏感 + {name: "Google_One_Free 大写", raw: "Google_One_Free", want: GeminiTierGoogleOneFree}, + {name: "GCP_STANDARD 全大写", raw: "GCP_STANDARD", want: GeminiTierGCPStandard}, + + // legacy 映射: Google One + {name: "AI_PREMIUM -> google_ai_pro", raw: "AI_PREMIUM", want: GeminiTierGoogleAIPro}, + {name: "FREE -> google_one_free", raw: "FREE", want: GeminiTierGoogleOneFree}, + {name: "GOOGLE_ONE_BASIC -> google_one_free", raw: "GOOGLE_ONE_BASIC", want: GeminiTierGoogleOneFree}, + {name: "GOOGLE_ONE_STANDARD -> google_one_free", raw: "GOOGLE_ONE_STANDARD", want: GeminiTierGoogleOneFree}, + {name: "GOOGLE_ONE_UNLIMITED -> google_ai_ultra", raw: "GOOGLE_ONE_UNLIMITED", want: GeminiTierGoogleAIUltra}, + {name: "GOOGLE_ONE_UNKNOWN -> google_one_unknown", raw: "GOOGLE_ONE_UNKNOWN", want: GeminiTierGoogleOneUnknown}, + + // legacy 映射: Code Assist + {name: "STANDARD -> gcp_standard", raw: "STANDARD", want: GeminiTierGCPStandard}, + {name: "PRO -> gcp_standard", raw: "PRO", want: GeminiTierGCPStandard}, + {name: "LEGACY -> gcp_standard", raw: "LEGACY", want: GeminiTierGCPStandard}, + {name: "ENTERPRISE -> gcp_enterprise", raw: "ENTERPRISE", want: GeminiTierGCPEnterprise}, + {name: "ULTRA -> gcp_enterprise", raw: 
"ULTRA", want: GeminiTierGCPEnterprise}, + + // kebab-case + {name: "standard-tier -> gcp_standard", raw: "standard-tier", want: GeminiTierGCPStandard}, + {name: "pro-tier -> gcp_standard", raw: "pro-tier", want: GeminiTierGCPStandard}, + {name: "ultra-tier -> gcp_enterprise", raw: "ultra-tier", want: GeminiTierGCPEnterprise}, + + // 未知值 + {name: "unknown_value -> 空", raw: "unknown_value", want: ""}, + {name: "random-text -> 空", raw: "random-text", want: ""}, + + // 带空白 + {name: "带前后空白", raw: " google_one_free ", want: GeminiTierGoogleOneFree}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := canonicalGeminiTierID(tt.raw) + if got != tt.want { + t.Fatalf("canonicalGeminiTierID(%q) = %q, want %q", tt.raw, got, tt.want) + } + }) + } +} + +// ===================== +// 新增测试:canonicalGeminiTierIDForOAuthType +// ===================== + +func TestCanonicalGeminiTierIDForOAuthType(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + oauthType string + tierID string + want string + }{ + // google_one 类型过滤 + {name: "google_one + google_one_free", oauthType: "google_one", tierID: "google_one_free", want: GeminiTierGoogleOneFree}, + {name: "google_one + google_ai_pro", oauthType: "google_one", tierID: "google_ai_pro", want: GeminiTierGoogleAIPro}, + {name: "google_one + google_ai_ultra", oauthType: "google_one", tierID: "google_ai_ultra", want: GeminiTierGoogleAIUltra}, + {name: "google_one + gcp_standard 被过滤", oauthType: "google_one", tierID: "gcp_standard", want: ""}, + {name: "google_one + aistudio_free 被过滤", oauthType: "google_one", tierID: "aistudio_free", want: ""}, + {name: "google_one + AI_PREMIUM 遗留映射", oauthType: "google_one", tierID: "AI_PREMIUM", want: GeminiTierGoogleAIPro}, + + // code_assist 类型过滤 + {name: "code_assist + gcp_standard", oauthType: "code_assist", tierID: "gcp_standard", want: GeminiTierGCPStandard}, + {name: "code_assist + gcp_enterprise", oauthType: "code_assist", 
tierID: "gcp_enterprise", want: GeminiTierGCPEnterprise}, + {name: "code_assist + google_one_free 被过滤", oauthType: "code_assist", tierID: "google_one_free", want: ""}, + {name: "code_assist + aistudio_free 被过滤", oauthType: "code_assist", tierID: "aistudio_free", want: ""}, + {name: "code_assist + STANDARD 遗留映射", oauthType: "code_assist", tierID: "STANDARD", want: GeminiTierGCPStandard}, + {name: "code_assist + standard-tier kebab", oauthType: "code_assist", tierID: "standard-tier", want: GeminiTierGCPStandard}, + + // ai_studio 类型过滤 + {name: "ai_studio + aistudio_free", oauthType: "ai_studio", tierID: "aistudio_free", want: GeminiTierAIStudioFree}, + {name: "ai_studio + aistudio_paid", oauthType: "ai_studio", tierID: "aistudio_paid", want: GeminiTierAIStudioPaid}, + {name: "ai_studio + gcp_standard 被过滤", oauthType: "ai_studio", tierID: "gcp_standard", want: ""}, + {name: "ai_studio + google_one_free 被过滤", oauthType: "ai_studio", tierID: "google_one_free", want: ""}, + + // 空值 + {name: "空 tierID", oauthType: "google_one", tierID: "", want: ""}, + {name: "空 oauthType + 有效 tierID", oauthType: "", tierID: "gcp_standard", want: GeminiTierGCPStandard}, + {name: "未知 oauthType 接受规范化值", oauthType: "unknown_type", tierID: "gcp_standard", want: GeminiTierGCPStandard}, + + // oauthType 大小写和空白 + {name: "GOOGLE_ONE 大写", oauthType: "GOOGLE_ONE", tierID: "google_one_free", want: GeminiTierGoogleOneFree}, + {name: "oauthType 带空白", oauthType: " code_assist ", tierID: "gcp_standard", want: GeminiTierGCPStandard}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := canonicalGeminiTierIDForOAuthType(tt.oauthType, tt.tierID) + if got != tt.want { + t.Fatalf("canonicalGeminiTierIDForOAuthType(%q, %q) = %q, want %q", tt.oauthType, tt.tierID, got, tt.want) + } + }) + } +} + +// ===================== +// 新增测试:extractTierIDFromAllowedTiers +// ===================== + +func TestExtractTierIDFromAllowedTiers(t *testing.T) { + t.Parallel() 
+ + tests := []struct { + name string + allowedTiers []geminicli.AllowedTier + want string + }{ + { + name: "nil 列表返回 LEGACY", + allowedTiers: nil, + want: "LEGACY", + }, + { + name: "空列表返回 LEGACY", + allowedTiers: []geminicli.AllowedTier{}, + want: "LEGACY", + }, + { + name: "有 IsDefault 的 tier", + allowedTiers: []geminicli.AllowedTier{ + {ID: "STANDARD", IsDefault: false}, + {ID: "PRO", IsDefault: true}, + {ID: "ENTERPRISE", IsDefault: false}, + }, + want: "PRO", + }, + { + name: "没有 IsDefault 取第一个非空", + allowedTiers: []geminicli.AllowedTier{ + {ID: "STANDARD", IsDefault: false}, + {ID: "ENTERPRISE", IsDefault: false}, + }, + want: "STANDARD", + }, + { + name: "IsDefault 的 ID 为空,取第一个非空", + allowedTiers: []geminicli.AllowedTier{ + {ID: "", IsDefault: true}, + {ID: "PRO", IsDefault: false}, + }, + want: "PRO", + }, + { + name: "所有 ID 都为空返回 LEGACY", + allowedTiers: []geminicli.AllowedTier{ + {ID: "", IsDefault: false}, + {ID: " ", IsDefault: false}, + }, + want: "LEGACY", + }, + { + name: "ID 带空白会被 trim", + allowedTiers: []geminicli.AllowedTier{ + {ID: " STANDARD ", IsDefault: true}, + }, + want: "STANDARD", + }, + { + name: "单个 tier 且 IsDefault", + allowedTiers: []geminicli.AllowedTier{ + {ID: "ENTERPRISE", IsDefault: true}, + }, + want: "ENTERPRISE", + }, + { + name: "单个 tier 非 IsDefault", + allowedTiers: []geminicli.AllowedTier{ + {ID: "ENTERPRISE", IsDefault: false}, + }, + want: "ENTERPRISE", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := extractTierIDFromAllowedTiers(tt.allowedTiers) + if got != tt.want { + t.Fatalf("extractTierIDFromAllowedTiers() = %q, want %q", got, tt.want) + } + }) + } +} + +// ===================== +// 新增测试:inferGoogleOneTier +// ===================== + +func TestInferGoogleOneTier(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + storageBytes int64 + want string + }{ + // 边界:<= 0 + {name: "0 bytes -> unknown", storageBytes: 0, want: 
GeminiTierGoogleOneUnknown}, + {name: "负数 -> unknown", storageBytes: -1, want: GeminiTierGoogleOneUnknown}, + + // > 100TB -> ultra + {name: "> 100TB -> ultra", storageBytes: int64(StorageTierUnlimited) + 1, want: GeminiTierGoogleAIUltra}, + {name: "200TB -> ultra", storageBytes: 200 * int64(TB), want: GeminiTierGoogleAIUltra}, + + // >= 2TB -> pro (但 <= 100TB) + {name: "正好 2TB -> pro", storageBytes: int64(StorageTierAIPremium), want: GeminiTierGoogleAIPro}, + {name: "5TB -> pro", storageBytes: 5 * int64(TB), want: GeminiTierGoogleAIPro}, + {name: "100TB 正好 -> pro (不是 > 100TB)", storageBytes: int64(StorageTierUnlimited), want: GeminiTierGoogleAIPro}, + + // >= 15GB -> free (但 < 2TB) + {name: "正好 15GB -> free", storageBytes: int64(StorageTierFree), want: GeminiTierGoogleOneFree}, + {name: "100GB -> free", storageBytes: 100 * int64(GB), want: GeminiTierGoogleOneFree}, + {name: "略低于 2TB -> free", storageBytes: int64(StorageTierAIPremium) - 1, want: GeminiTierGoogleOneFree}, + + // < 15GB -> unknown + {name: "1GB -> unknown", storageBytes: int64(GB), want: GeminiTierGoogleOneUnknown}, + {name: "略低于 15GB -> unknown", storageBytes: int64(StorageTierFree) - 1, want: GeminiTierGoogleOneUnknown}, + {name: "1 byte -> unknown", storageBytes: 1, want: GeminiTierGoogleOneUnknown}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := inferGoogleOneTier(tt.storageBytes) + if got != tt.want { + t.Fatalf("inferGoogleOneTier(%d) = %q, want %q", tt.storageBytes, got, tt.want) + } + }) + } +} + +// ===================== +// 新增测试:isNonRetryableGeminiOAuthError +// ===================== + +func TestIsNonRetryableGeminiOAuthError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + want bool + }{ + {name: "invalid_grant", err: fmt.Errorf("error: invalid_grant"), want: true}, + {name: "invalid_client", err: fmt.Errorf("oauth error: invalid_client"), want: true}, + {name: "unauthorized_client", err: 
fmt.Errorf("unauthorized_client: mismatch"), want: true}, + {name: "access_denied", err: fmt.Errorf("access_denied by user"), want: true}, + {name: "普通网络错误", err: fmt.Errorf("connection timeout"), want: false}, + {name: "HTTP 500 错误", err: fmt.Errorf("server error 500"), want: false}, + {name: "空错误信息", err: fmt.Errorf(""), want: false}, + {name: "包含 invalid 但不是完整匹配", err: fmt.Errorf("invalid request"), want: false}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := isNonRetryableGeminiOAuthError(tt.err) + if got != tt.want { + t.Fatalf("isNonRetryableGeminiOAuthError(%v) = %v, want %v", tt.err, got, tt.want) + } + }) + } +} + +// ===================== +// 新增测试:BuildAccountCredentials +// ===================== + +func TestGeminiOAuthService_BuildAccountCredentials(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{}) + defer svc.Stop() + + t.Run("完整字段", func(t *testing.T) { + t.Parallel() + tokenInfo := &GeminiTokenInfo{ + AccessToken: "access-123", + RefreshToken: "refresh-456", + ExpiresIn: 3600, + ExpiresAt: 1700000000, + TokenType: "Bearer", + Scope: "openid email", + ProjectID: "my-project", + TierID: "gcp_standard", + OAuthType: "code_assist", + Extra: map[string]any{ + "drive_storage_limit": int64(2199023255552), + }, + } + + creds := svc.BuildAccountCredentials(tokenInfo) + + assertCredStr(t, creds, "access_token", "access-123") + assertCredStr(t, creds, "refresh_token", "refresh-456") + assertCredStr(t, creds, "token_type", "Bearer") + assertCredStr(t, creds, "scope", "openid email") + assertCredStr(t, creds, "project_id", "my-project") + assertCredStr(t, creds, "tier_id", "gcp_standard") + assertCredStr(t, creds, "oauth_type", "code_assist") + assertCredStr(t, creds, "expires_at", "1700000000") + + if _, ok := creds["drive_storage_limit"]; !ok { + t.Fatal("extra 字段 drive_storage_limit 未包含在 creds 中") + } + }) + + t.Run("最小字段(仅 access_token 和 
expires_at)", func(t *testing.T) { + t.Parallel() + tokenInfo := &GeminiTokenInfo{ + AccessToken: "token-only", + ExpiresAt: 1700000000, + } + + creds := svc.BuildAccountCredentials(tokenInfo) + + assertCredStr(t, creds, "access_token", "token-only") + assertCredStr(t, creds, "expires_at", "1700000000") + + // 可选字段不应存在 + for _, key := range []string{"refresh_token", "token_type", "scope", "project_id", "tier_id", "oauth_type"} { + if _, ok := creds[key]; ok { + t.Fatalf("不应包含空字段 %q", key) + } + } + }) + + t.Run("无效 tier_id 被静默跳过", func(t *testing.T) { + t.Parallel() + tokenInfo := &GeminiTokenInfo{ + AccessToken: "token", + ExpiresAt: 1700000000, + TierID: "tier with spaces", + } + + creds := svc.BuildAccountCredentials(tokenInfo) + + if _, ok := creds["tier_id"]; ok { + t.Fatal("无效 tier_id 不应被存入 creds") + } + }) + + t.Run("超长 tier_id 被静默跳过", func(t *testing.T) { + t.Parallel() + tokenInfo := &GeminiTokenInfo{ + AccessToken: "token", + ExpiresAt: 1700000000, + TierID: strings.Repeat("x", 65), + } + + creds := svc.BuildAccountCredentials(tokenInfo) + + if _, ok := creds["tier_id"]; ok { + t.Fatal("超长 tier_id 不应被存入 creds") + } + }) + + t.Run("无 extra 字段", func(t *testing.T) { + t.Parallel() + tokenInfo := &GeminiTokenInfo{ + AccessToken: "token", + ExpiresAt: 1700000000, + RefreshToken: "rt", + } + + creds := svc.BuildAccountCredentials(tokenInfo) + + // 仅包含基础字段 + if len(creds) != 3 { // access_token, expires_at, refresh_token + t.Fatalf("creds 字段数量不匹配: got=%d want=3, keys=%v", len(creds), credKeys(creds)) + } + }) +} + +// ===================== +// 新增测试:GetOAuthConfig +// ===================== + +func TestGeminiOAuthService_GetOAuthConfig(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *config.Config + wantEnabled bool + }{ + { + name: "无自定义 OAuth 客户端", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{}, + }, + }, + wantEnabled: false, + }, + { + name: "仅 ClientID 无 ClientSecret", + cfg: 
&config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientID: "custom-id", + }, + }, + }, + wantEnabled: false, + }, + { + name: "仅 ClientSecret 无 ClientID", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientSecret: "custom-secret", + }, + }, + }, + wantEnabled: false, + }, + { + name: "使用内置 Gemini CLI ClientID(不算自定义)", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientID: geminicli.GeminiCLIOAuthClientID, + ClientSecret: "some-secret", + }, + }, + }, + wantEnabled: false, + }, + { + name: "自定义 OAuth 客户端(非内置 ID)", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientID: "my-custom-client-id", + ClientSecret: "my-custom-client-secret", + }, + }, + }, + wantEnabled: true, + }, + { + name: "带空白的自定义客户端", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientID: " my-custom-client-id ", + ClientSecret: " my-custom-client-secret ", + }, + }, + }, + wantEnabled: true, + }, + { + name: "纯空白字符串不算配置", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientID: " ", + ClientSecret: " ", + }, + }, + }, + wantEnabled: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + svc := NewGeminiOAuthService(nil, nil, nil, nil, tt.cfg) + defer svc.Stop() + + result := svc.GetOAuthConfig() + if result.AIStudioOAuthEnabled != tt.wantEnabled { + t.Fatalf("AIStudioOAuthEnabled = %v, want %v", result.AIStudioOAuthEnabled, tt.wantEnabled) + } + // RequiredRedirectURIs 始终包含 AI Studio redirect URI + if len(result.RequiredRedirectURIs) != 1 || result.RequiredRedirectURIs[0] != geminicli.AIStudioOAuthRedirectURI { + t.Fatalf("RequiredRedirectURIs 不匹配: got=%v", result.RequiredRedirectURIs) + } + }) + } +} + +// ===================== +// 新增测试:GeminiOAuthService.Stop +// ===================== 
+ +func TestGeminiOAuthService_Stop_NoPanic(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{}) + + // 调用 Stop 不应 panic + svc.Stop() + // 多次调用也不应 panic + svc.Stop() +} + +// ===================== +// mock: GeminiOAuthClient +// ===================== + +type mockGeminiOAuthClient struct { + exchangeCodeFunc func(ctx context.Context, oauthType, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) + refreshTokenFunc func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) +} + +func (m *mockGeminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) { + if m.exchangeCodeFunc != nil { + return m.exchangeCodeFunc(ctx, oauthType, code, codeVerifier, redirectURI, proxyURL) + } + panic("ExchangeCode not implemented") +} + +func (m *mockGeminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + if m.refreshTokenFunc != nil { + return m.refreshTokenFunc(ctx, oauthType, refreshToken, proxyURL) + } + panic("RefreshToken not implemented") +} + +// ===================== +// mock: GeminiCliCodeAssistClient +// ===================== + +type mockGeminiCodeAssistClient struct { + loadCodeAssistFunc func(ctx context.Context, accessToken, proxyURL string, req *geminicli.LoadCodeAssistRequest) (*geminicli.LoadCodeAssistResponse, error) + onboardUserFunc func(ctx context.Context, accessToken, proxyURL string, req *geminicli.OnboardUserRequest) (*geminicli.OnboardUserResponse, error) +} + +func (m *mockGeminiCodeAssistClient) LoadCodeAssist(ctx context.Context, accessToken, proxyURL string, req *geminicli.LoadCodeAssistRequest) (*geminicli.LoadCodeAssistResponse, error) { + if m.loadCodeAssistFunc != nil { + return m.loadCodeAssistFunc(ctx, accessToken, proxyURL, req) + } + panic("LoadCodeAssist not 
implemented") +} + +func (m *mockGeminiCodeAssistClient) OnboardUser(ctx context.Context, accessToken, proxyURL string, req *geminicli.OnboardUserRequest) (*geminicli.OnboardUserResponse, error) { + if m.onboardUserFunc != nil { + return m.onboardUserFunc(ctx, accessToken, proxyURL, req) + } + panic("OnboardUser not implemented") +} + +// ===================== +// mock: ProxyRepository (最小实现) +// ===================== + +type mockGeminiProxyRepo struct { + getByIDFunc func(ctx context.Context, id int64) (*Proxy, error) +} + +func (m *mockGeminiProxyRepo) Create(ctx context.Context, proxy *Proxy) error { panic("not impl") } +func (m *mockGeminiProxyRepo) GetByID(ctx context.Context, id int64) (*Proxy, error) { + if m.getByIDFunc != nil { + return m.getByIDFunc(ctx, id) + } + return nil, fmt.Errorf("proxy not found") +} +func (m *mockGeminiProxyRepo) ListByIDs(ctx context.Context, ids []int64) ([]Proxy, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) Update(ctx context.Context, proxy *Proxy) error { panic("not impl") } +func (m *mockGeminiProxyRepo) Delete(ctx context.Context, id int64) error { panic("not impl") } +func (m *mockGeminiProxyRepo) List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) ListActive(ctx context.Context) ([]Proxy, error) { panic("not impl") } +func (m *mockGeminiProxyRepo) ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) { + panic("not impl") +} 
+func (m *mockGeminiProxyRepo) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) { + panic("not impl") +} + +// mockDriveClient implements geminicli.DriveClient for tests. +type mockDriveClient struct { + getStorageQuotaFunc func(ctx context.Context, accessToken, proxyURL string) (*geminicli.DriveStorageInfo, error) +} + +func (m *mockDriveClient) GetStorageQuota(ctx context.Context, accessToken, proxyURL string) (*geminicli.DriveStorageInfo, error) { + if m.getStorageQuotaFunc != nil { + return m.getStorageQuotaFunc(ctx, accessToken, proxyURL) + } + return nil, fmt.Errorf("drive API not available in test") +} + +// ===================== +// 新增测试:GeminiOAuthService.RefreshToken(含重试逻辑) +// ===================== + +func TestGeminiOAuthService_RefreshToken_Success(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "new-access", + RefreshToken: "new-refresh", + TokenType: "Bearer", + ExpiresIn: 3600, + Scope: "openid", + }, nil + }, + } + + svc := NewGeminiOAuthService(nil, client, nil, nil, &config.Config{}) + defer svc.Stop() + + info, err := svc.RefreshToken(context.Background(), "code_assist", "old-refresh", "") + if err != nil { + t.Fatalf("RefreshToken 返回错误: %v", err) + } + if info.AccessToken != "new-access" { + t.Fatalf("AccessToken 不匹配: got=%q", info.AccessToken) + } + if info.RefreshToken != "new-refresh" { + t.Fatalf("RefreshToken 不匹配: got=%q", info.RefreshToken) + } + if info.ExpiresAt == 0 { + t.Fatal("ExpiresAt 不应为 0") + } +} + +func 
TestGeminiOAuthService_RefreshToken_NonRetryableError(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return nil, fmt.Errorf("invalid_grant: token revoked") + }, + } + + svc := NewGeminiOAuthService(nil, client, nil, nil, &config.Config{}) + defer svc.Stop() + + _, err := svc.RefreshToken(context.Background(), "code_assist", "revoked-token", "") + if err == nil { + t.Fatal("RefreshToken 应返回错误(不可重试的 invalid_grant)") + } + if !strings.Contains(err.Error(), "invalid_grant") { + t.Fatalf("错误应包含 invalid_grant: got=%q", err.Error()) + } +} + +func TestGeminiOAuthService_RefreshToken_RetryableError(t *testing.T) { + t.Parallel() + + callCount := 0 + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + callCount++ + if callCount <= 2 { + return nil, fmt.Errorf("temporary network error") + } + return &geminicli.TokenResponse{ + AccessToken: "recovered", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewGeminiOAuthService(nil, client, nil, nil, &config.Config{}) + defer svc.Stop() + + info, err := svc.RefreshToken(context.Background(), "code_assist", "rt", "") + if err != nil { + t.Fatalf("RefreshToken 应在重试后成功: %v", err) + } + if info.AccessToken != "recovered" { + t.Fatalf("AccessToken 不匹配: got=%q", info.AccessToken) + } + if callCount < 3 { + t.Fatalf("应至少调用 3 次(2 次失败 + 1 次成功): got=%d", callCount) + } +} + +// ===================== +// 新增测试:GeminiOAuthService.RefreshAccountToken +// ===================== + +func TestGeminiOAuthService_RefreshAccountToken_NotGeminiOAuth(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + } + + _, err := 
svc.RefreshAccountToken(context.Background(), account) + if err == nil { + t.Fatal("应返回错误(非 Gemini OAuth 账号)") + } + if !strings.Contains(err.Error(), "not a Gemini OAuth account") { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_NoRefreshToken(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "at", + "oauth_type": "code_assist", + }, + } + + _, err := svc.RefreshAccountToken(context.Background(), account) + if err == nil { + t.Fatal("应返回错误(无 refresh_token)") + } + if !strings.Contains(err.Error(), "no refresh token") { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_AIStudio(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "refreshed-at", + RefreshToken: "refreshed-rt", + ExpiresIn: 3600, + TokenType: "Bearer", + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "old-at", + "refresh_token": "old-rt", + "oauth_type": "ai_studio", + "tier_id": "aistudio_free", + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + if info.AccessToken != "refreshed-at" { + t.Fatalf("AccessToken 不匹配: got=%q", info.AccessToken) + } + if info.OAuthType != "ai_studio" { + t.Fatalf("OAuthType 不匹配: got=%q", info.OAuthType) + } +} + +func 
TestGeminiOAuthService_RefreshAccountToken_CodeAssist_WithProjectID(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "refreshed", + RefreshToken: "new-rt", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "old-at", + "refresh_token": "old-rt", + "oauth_type": "code_assist", + "project_id": "my-project", + "tier_id": "gcp_standard", + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + if info.ProjectID != "my-project" { + t.Fatalf("ProjectID 应保留: got=%q", info.ProjectID) + } + if info.TierID != GeminiTierGCPStandard { + t.Fatalf("TierID 不匹配: got=%q want=%q", info.TierID, GeminiTierGCPStandard) + } + if info.OAuthType != "code_assist" { + t.Fatalf("OAuthType 不匹配: got=%q", info.OAuthType) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_DefaultOAuthType(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + if oauthType != "code_assist" { + t.Errorf("默认 oauthType 应为 code_assist: got=%q", oauthType) + } + return &geminicli.TokenResponse{ + AccessToken: "refreshed", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{}) + defer svc.Stop() + + // 无 oauth_type 凭据的旧账号 + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "old-rt", + "project_id": "proj", + "tier_id": 
"STANDARD", + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + if info.OAuthType != "code_assist" { + t.Fatalf("OAuthType 应默认为 code_assist: got=%q", info.OAuthType) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_WithProxy(t *testing.T) { + t.Parallel() + + proxyRepo := &mockGeminiProxyRepo{ + getByIDFunc: func(ctx context.Context, id int64) (*Proxy, error) { + return &Proxy{ + Protocol: "http", + Host: "proxy.test", + Port: 3128, + }, nil + }, + } + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + if proxyURL != "http://proxy.test:3128" { + t.Errorf("proxyURL 不匹配: got=%q", proxyURL) + } + return &geminicli.TokenResponse{ + AccessToken: "refreshed", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewGeminiOAuthService(proxyRepo, client, nil, nil, &config.Config{}) + defer svc.Stop() + + proxyID := int64(5) + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + ProxyID: &proxyID, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "code_assist", + "project_id": "proj", + }, + } + + _, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_NoProjectID_AutoDetect(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "at", + ExpiresIn: 3600, + }, nil + }, + } + + codeAssist := &mockGeminiCodeAssistClient{ + loadCodeAssistFunc: func(ctx context.Context, accessToken, proxyURL string, req *geminicli.LoadCodeAssistRequest) (*geminicli.LoadCodeAssistResponse, error) { + return 
&geminicli.LoadCodeAssistResponse{ + CloudAICompanionProject: "auto-project-123", + CurrentTier: &geminicli.TierInfo{ID: "STANDARD"}, + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "code_assist", + // 无 project_id,触发 fetchProjectID + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + if info.ProjectID != "auto-project-123" { + t.Fatalf("ProjectID 应为自动检测值: got=%q", info.ProjectID) + } + if info.TierID != GeminiTierGCPStandard { + t.Fatalf("TierID 不匹配: got=%q", info.TierID) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_NoProjectID_FailsEmpty(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "at", + ExpiresIn: 3600, + }, nil + }, + } + + // 返回有 currentTier 但无 cloudaicompanionProject 的响应, + // 使 fetchProjectID 走"已注册用户"路径(尝试 Cloud Resource Manager -> 失败 -> 返回错误), + // 避免走 onboardUser 路径(5 次重试 x 2 秒 = 10 秒超时) + codeAssist := &mockGeminiCodeAssistClient{ + loadCodeAssistFunc: func(ctx context.Context, accessToken, proxyURL string, req *geminicli.LoadCodeAssistRequest) (*geminicli.LoadCodeAssistResponse, error) { + return &geminicli.LoadCodeAssistResponse{ + CurrentTier: &geminicli.TierInfo{ID: "STANDARD"}, + // 无 CloudAICompanionProject + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": 
"code_assist", + }, + } + + _, err := svc.RefreshAccountToken(context.Background(), account) + if err == nil { + t.Fatal("应返回错误(无法检测 project_id)") + } + if !strings.Contains(err.Error(), "project_id") { + t.Fatalf("错误信息应包含 project_id: got=%q", err.Error()) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_GoogleOne_FreshCache(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "at", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "google_one", + "project_id": "proj", + "tier_id": "google_ai_pro", + }, + Extra: map[string]any{ + // 缓存刷新时间在 24 小时内 + "drive_tier_updated_at": time.Now().Add(-1 * time.Hour).Format(time.RFC3339), + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + // 缓存新鲜,应使用已有的 tier_id + if info.TierID != GeminiTierGoogleAIPro { + t.Fatalf("TierID 应使用缓存值: got=%q want=%q", info.TierID, GeminiTierGoogleAIPro) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_GoogleOne_NoTierID_DefaultsFree(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "at", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &mockDriveClient{}, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + 
"refresh_token": "rt", + "oauth_type": "google_one", + "project_id": "proj", + // 无 tier_id + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + // FetchGoogleOneTier 会被调用但 oauthClient(此处 mock)不实现 Drive API, + // svc.FetchGoogleOneTier 使用真实 DriveClient 会失败,最终回退到默认值。 + // 由于没有 tier_id 且 FetchGoogleOneTier 失败,应默认为 google_one_free + if info.TierID != GeminiTierGoogleOneFree { + t.Fatalf("TierID 应为默认 free: got=%q", info.TierID) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_UnauthorizedClient_Fallback(t *testing.T) { + t.Parallel() + + callCount := 0 + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + callCount++ + if oauthType == "code_assist" { + return nil, fmt.Errorf("unauthorized_client: client mismatch") + } + // ai_studio 路径成功 + return &geminicli.TokenResponse{ + AccessToken: "recovered", + ExpiresIn: 3600, + }, nil + }, + } + + // 启用自定义 OAuth 客户端以触发 fallback 路径 + cfg := &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientID: "custom-id", + ClientSecret: "custom-secret", + }, + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, cfg) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "code_assist", + "project_id": "proj", + "tier_id": "gcp_standard", + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 应在 fallback 后成功: %v", err) + } + if info.AccessToken != "recovered" { + t.Fatalf("AccessToken 不匹配: got=%q", info.AccessToken) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_UnauthorizedClient_NoFallback(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + 
refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return nil, fmt.Errorf("unauthorized_client: client mismatch") + }, + } + + // 无自定义 OAuth 客户端,无法 fallback + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "code_assist", + "project_id": "proj", + }, + } + + _, err := svc.RefreshAccountToken(context.Background(), account) + if err == nil { + t.Fatal("应返回错误(无 fallback)") + } + if !strings.Contains(err.Error(), "OAuth client mismatch") { + t.Fatalf("错误应包含 OAuth client mismatch: got=%q", err.Error()) + } +} + +// ===================== +// 新增测试:GeminiOAuthService.ExchangeCode +// ===================== + +func TestGeminiOAuthService_ExchangeCode_SessionNotFound(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{}) + defer svc.Stop() + + _, err := svc.ExchangeCode(context.Background(), &GeminiExchangeCodeInput{ + SessionID: "nonexistent", + State: "some-state", + Code: "some-code", + }) + if err == nil { + t.Fatal("应返回错误(session 不存在)") + } + if !strings.Contains(err.Error(), "session not found") { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestGeminiOAuthService_ExchangeCode_InvalidState(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{}) + defer svc.Stop() + + // 手动创建 session(必须设置 CreatedAt,否则会因 TTL 过期被拒绝) + svc.sessionStore.Set("test-session", &geminicli.OAuthSession{ + State: "correct-state", + CodeVerifier: "verifier", + OAuthType: "ai_studio", + CreatedAt: time.Now(), + }) + + _, err := svc.ExchangeCode(context.Background(), &GeminiExchangeCodeInput{ + SessionID: "test-session", + State: "wrong-state", + Code: "code", + }) + if err == nil { + t.Fatal("应返回错误(state 不匹配)") + } + 
if !strings.Contains(err.Error(), "invalid state") { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestGeminiOAuthService_ExchangeCode_EmptyState(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, nil, &config.Config{}) + defer svc.Stop() + + svc.sessionStore.Set("test-session", &geminicli.OAuthSession{ + State: "correct-state", + CodeVerifier: "verifier", + CreatedAt: time.Now(), + }) + + _, err := svc.ExchangeCode(context.Background(), &GeminiExchangeCodeInput{ + SessionID: "test-session", + State: "", + Code: "code", + }) + if err == nil { + t.Fatal("应返回错误(空 state)") + } +} + +// ===================== +// 辅助函数 +// ===================== + +func assertCredStr(t *testing.T, creds map[string]any, key, want string) { + t.Helper() + raw, ok := creds[key] + if !ok { + t.Fatalf("creds 缺少 key=%q", key) + } + got, ok := raw.(string) + if !ok { + t.Fatalf("creds[%q] 不是 string: %T", key, raw) + } + if got != want { + t.Fatalf("creds[%q] = %q, want %q", key, got, want) + } +} + +func credKeys(m map[string]any) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index e9423ddb..6990caca 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -26,6 +26,15 @@ type Group struct { ImagePrice2K *float64 ImagePrice4K *float64 + // Sora 按次计费配置(阶段 1) + SoraImagePrice360 *float64 + SoraImagePrice540 *float64 + SoraVideoPricePerRequest *float64 + SoraVideoPricePerRequestHD *float64 + + // Sora 存储配额 + SoraStorageQuotaBytes int64 + // Claude Code 客户端限制 ClaudeCodeOnly bool FallbackGroupID *int64 @@ -95,6 +104,18 @@ func (g *Group) GetImagePrice(imageSize string) *float64 { } } +// GetSoraImagePrice 根据 Sora 图片尺寸返回价格(360/540) +func (g *Group) GetSoraImagePrice(imageSize string) *float64 { + switch imageSize { + case "360": + return g.SoraImagePrice360 + case "540": + return 
g.SoraImagePrice540 + default: + return g.SoraImagePrice360 + } +} + // IsGroupContextValid reports whether a group from context has the fields required for routing decisions. func IsGroupContextValid(group *Group) bool { if group == nil { diff --git a/backend/internal/service/idempotency.go b/backend/internal/service/idempotency.go new file mode 100644 index 00000000..2a86bd60 --- /dev/null +++ b/backend/internal/service/idempotency.go @@ -0,0 +1,471 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "strconv" + "strings" + "sync" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/util/logredact" +) + +const ( + IdempotencyStatusProcessing = "processing" + IdempotencyStatusSucceeded = "succeeded" + IdempotencyStatusFailedRetryable = "failed_retryable" +) + +var ( + ErrIdempotencyKeyRequired = infraerrors.BadRequest("IDEMPOTENCY_KEY_REQUIRED", "idempotency key is required") + ErrIdempotencyKeyInvalid = infraerrors.BadRequest("IDEMPOTENCY_KEY_INVALID", "idempotency key is invalid") + ErrIdempotencyKeyConflict = infraerrors.Conflict("IDEMPOTENCY_KEY_CONFLICT", "idempotency key reused with different payload") + ErrIdempotencyInProgress = infraerrors.Conflict("IDEMPOTENCY_IN_PROGRESS", "idempotent request is still processing") + ErrIdempotencyRetryBackoff = infraerrors.Conflict("IDEMPOTENCY_RETRY_BACKOFF", "idempotent request is in retry backoff window") + ErrIdempotencyStoreUnavail = infraerrors.ServiceUnavailable("IDEMPOTENCY_STORE_UNAVAILABLE", "idempotency store unavailable") + ErrIdempotencyInvalidPayload = infraerrors.BadRequest("IDEMPOTENCY_PAYLOAD_INVALID", "failed to normalize request payload") +) + +type IdempotencyRecord struct { + ID int64 + Scope string + IdempotencyKeyHash string + RequestFingerprint string + Status string + ResponseStatus *int + ResponseBody *string + ErrorReason *string + LockedUntil *time.Time + ExpiresAt 
time.Time + CreatedAt time.Time + UpdatedAt time.Time +} + +type IdempotencyRepository interface { + CreateProcessing(ctx context.Context, record *IdempotencyRecord) (bool, error) + GetByScopeAndKeyHash(ctx context.Context, scope, keyHash string) (*IdempotencyRecord, error) + TryReclaim(ctx context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) + ExtendProcessingLock(ctx context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) + MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error + MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error + DeleteExpired(ctx context.Context, now time.Time, limit int) (int64, error) +} + +type IdempotencyConfig struct { + DefaultTTL time.Duration + SystemOperationTTL time.Duration + ProcessingTimeout time.Duration + FailedRetryBackoff time.Duration + MaxStoredResponseLen int + ObserveOnly bool +} + +func DefaultIdempotencyConfig() IdempotencyConfig { + return IdempotencyConfig{ + DefaultTTL: 24 * time.Hour, + SystemOperationTTL: 1 * time.Hour, + ProcessingTimeout: 30 * time.Second, + FailedRetryBackoff: 5 * time.Second, + MaxStoredResponseLen: 64 * 1024, + ObserveOnly: true, // 默认先观察再强制,避免老客户端立刻中断 + } +} + +type IdempotencyExecuteOptions struct { + Scope string + ActorScope string + Method string + Route string + IdempotencyKey string + Payload any + TTL time.Duration + RequireKey bool +} + +type IdempotencyExecuteResult struct { + Data any + Replayed bool +} + +type IdempotencyCoordinator struct { + repo IdempotencyRepository + cfg IdempotencyConfig +} + +var ( + defaultIdempotencyMu sync.RWMutex + defaultIdempotencySvc *IdempotencyCoordinator +) + +func SetDefaultIdempotencyCoordinator(svc *IdempotencyCoordinator) { + defaultIdempotencyMu.Lock() + defaultIdempotencySvc = svc + defaultIdempotencyMu.Unlock() +} + +func 
DefaultIdempotencyCoordinator() *IdempotencyCoordinator { + defaultIdempotencyMu.RLock() + defer defaultIdempotencyMu.RUnlock() + return defaultIdempotencySvc +} + +func DefaultWriteIdempotencyTTL() time.Duration { + defaultTTL := DefaultIdempotencyConfig().DefaultTTL + if coordinator := DefaultIdempotencyCoordinator(); coordinator != nil && coordinator.cfg.DefaultTTL > 0 { + return coordinator.cfg.DefaultTTL + } + return defaultTTL +} + +func DefaultSystemOperationIdempotencyTTL() time.Duration { + defaultTTL := DefaultIdempotencyConfig().SystemOperationTTL + if coordinator := DefaultIdempotencyCoordinator(); coordinator != nil && coordinator.cfg.SystemOperationTTL > 0 { + return coordinator.cfg.SystemOperationTTL + } + return defaultTTL +} + +func NewIdempotencyCoordinator(repo IdempotencyRepository, cfg IdempotencyConfig) *IdempotencyCoordinator { + return &IdempotencyCoordinator{ + repo: repo, + cfg: cfg, + } +} + +func NormalizeIdempotencyKey(raw string) (string, error) { + key := strings.TrimSpace(raw) + if key == "" { + return "", nil + } + if len(key) > 128 { + return "", ErrIdempotencyKeyInvalid + } + for _, r := range key { + if r < 33 || r > 126 { + return "", ErrIdempotencyKeyInvalid + } + } + return key, nil +} + +func HashIdempotencyKey(key string) string { + sum := sha256.Sum256([]byte(key)) + return hex.EncodeToString(sum[:]) +} + +func BuildIdempotencyFingerprint(method, route, actorScope string, payload any) (string, error) { + if method == "" { + method = "POST" + } + if route == "" { + route = "/" + } + if actorScope == "" { + actorScope = "anonymous" + } + + raw, err := json.Marshal(payload) + if err != nil { + return "", ErrIdempotencyInvalidPayload.WithCause(err) + } + sum := sha256.Sum256([]byte( + strings.ToUpper(method) + "\n" + route + "\n" + actorScope + "\n" + string(raw), + )) + return hex.EncodeToString(sum[:]), nil +} + +func RetryAfterSecondsFromError(err error) int { + appErr := new(infraerrors.ApplicationError) + if 
!errors.As(err, &appErr) || appErr == nil || appErr.Metadata == nil { + return 0 + } + v := strings.TrimSpace(appErr.Metadata["retry_after"]) + if v == "" { + return 0 + } + seconds, convErr := strconv.Atoi(v) + if convErr != nil || seconds <= 0 { + return 0 + } + return seconds +} + +func (c *IdempotencyCoordinator) Execute( + ctx context.Context, + opts IdempotencyExecuteOptions, + execute func(context.Context) (any, error), +) (*IdempotencyExecuteResult, error) { + if execute == nil { + return nil, infraerrors.InternalServer("IDEMPOTENCY_EXECUTOR_NIL", "idempotency executor is nil") + } + + key, err := NormalizeIdempotencyKey(opts.IdempotencyKey) + if err != nil { + return nil, err + } + if key == "" { + if opts.RequireKey && !c.cfg.ObserveOnly { + return nil, ErrIdempotencyKeyRequired + } + data, execErr := execute(ctx) + if execErr != nil { + return nil, execErr + } + return &IdempotencyExecuteResult{Data: data}, nil + } + if c.repo == nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "repo_nil") + return nil, ErrIdempotencyStoreUnavail + } + + if opts.Scope == "" { + return nil, infraerrors.BadRequest("IDEMPOTENCY_SCOPE_REQUIRED", "idempotency scope is required") + } + + fingerprint, err := BuildIdempotencyFingerprint(opts.Method, opts.Route, opts.ActorScope, opts.Payload) + if err != nil { + return nil, err + } + + ttl := opts.TTL + if ttl <= 0 { + ttl = c.cfg.DefaultTTL + } + now := time.Now() + expiresAt := now.Add(ttl) + lockedUntil := now.Add(c.cfg.ProcessingTimeout) + keyHash := HashIdempotencyKey(key) + + record := &IdempotencyRecord{ + Scope: opts.Scope, + IdempotencyKeyHash: keyHash, + RequestFingerprint: fingerprint, + Status: IdempotencyStatusProcessing, + LockedUntil: &lockedUntil, + ExpiresAt: expiresAt, + } + + owner, err := c.repo.CreateProcessing(ctx, record) + if err != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "create_processing_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, 
"unknown->store_unavailable", false, map[string]string{ + "operation": "create_processing", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(err) + } + if owner { + recordIdempotencyClaim(opts.Route, opts.Scope, map[string]string{"mode": "new_claim"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "none->processing", false, map[string]string{ + "claim_mode": "new", + }) + } + if !owner { + existing, getErr := c.repo.GetByScopeAndKeyHash(ctx, opts.Scope, keyHash) + if getErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "get_existing_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{ + "operation": "get_existing", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(getErr) + } + if existing == nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "missing_existing") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{ + "operation": "missing_existing", + }) + return nil, ErrIdempotencyStoreUnavail + } + if existing.RequestFingerprint != fingerprint { + recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "fingerprint_mismatch"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "existing->fingerprint_mismatch", false, nil) + return nil, ErrIdempotencyKeyConflict + } + reclaimedByExpired := false + if !existing.ExpiresAt.After(now) { + taken, reclaimErr := c.repo.TryReclaim(ctx, existing.ID, existing.Status, now, lockedUntil, expiresAt) + if reclaimErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "try_reclaim_expired_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, existing.Status+"->store_unavailable", false, map[string]string{ + "operation": "try_reclaim_expired", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(reclaimErr) + } + if taken { + reclaimedByExpired = true + recordIdempotencyClaim(opts.Route, 
opts.Scope, map[string]string{"mode": "expired_reclaim"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, existing.Status+"->processing", false, map[string]string{ + "claim_mode": "expired_reclaim", + }) + record.ID = existing.ID + } else { + latest, latestErr := c.repo.GetByScopeAndKeyHash(ctx, opts.Scope, keyHash) + if latestErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "get_existing_after_expired_reclaim_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{ + "operation": "get_existing_after_expired_reclaim", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(latestErr) + } + if latest == nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "missing_existing_after_expired_reclaim") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{ + "operation": "missing_existing_after_expired_reclaim", + }) + return nil, ErrIdempotencyStoreUnavail + } + if latest.RequestFingerprint != fingerprint { + recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "fingerprint_mismatch"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "existing->fingerprint_mismatch", false, nil) + return nil, ErrIdempotencyKeyConflict + } + existing = latest + } + } + + if !reclaimedByExpired { + switch existing.Status { + case IdempotencyStatusSucceeded: + data, parseErr := c.decodeStoredResponse(existing.ResponseBody) + if parseErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "decode_stored_response_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "succeeded->store_unavailable", false, map[string]string{ + "operation": "decode_stored_response", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(parseErr) + } + recordIdempotencyReplay(opts.Route, opts.Scope, nil) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "succeeded->replayed", true, nil) + 
return &IdempotencyExecuteResult{Data: data, Replayed: true}, nil + case IdempotencyStatusProcessing: + recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "in_progress"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->conflict", false, nil) + return nil, c.conflictWithRetryAfter(ErrIdempotencyInProgress, existing.LockedUntil, now) + case IdempotencyStatusFailedRetryable: + if existing.LockedUntil != nil && existing.LockedUntil.After(now) { + recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "retry_backoff"}) + recordIdempotencyRetryBackoff(opts.Route, opts.Scope, nil) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->retry_backoff_conflict", false, nil) + return nil, c.conflictWithRetryAfter(ErrIdempotencyRetryBackoff, existing.LockedUntil, now) + } + taken, reclaimErr := c.repo.TryReclaim(ctx, existing.ID, IdempotencyStatusFailedRetryable, now, lockedUntil, expiresAt) + if reclaimErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "try_reclaim_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->store_unavailable", false, map[string]string{ + "operation": "try_reclaim", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(reclaimErr) + } + if !taken { + recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "reclaim_race"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->conflict", false, map[string]string{ + "conflict": "reclaim_race", + }) + return nil, c.conflictWithRetryAfter(ErrIdempotencyInProgress, existing.LockedUntil, now) + } + recordIdempotencyClaim(opts.Route, opts.Scope, map[string]string{"mode": "reclaim"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->processing", false, map[string]string{ + "claim_mode": "reclaim", + }) + record.ID = existing.ID + default: + recordIdempotencyConflict(opts.Route, opts.Scope, 
map[string]string{"reason": "unexpected_status"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "existing->conflict", false, map[string]string{ + "status": existing.Status, + }) + return nil, ErrIdempotencyKeyConflict + } + } + } + + if record.ID == 0 { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "record_id_missing") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{ + "operation": "record_id_missing", + }) + return nil, ErrIdempotencyStoreUnavail + } + + execStart := time.Now() + defer func() { + recordIdempotencyProcessingDuration(opts.Route, opts.Scope, time.Since(execStart), nil) + }() + + data, execErr := execute(ctx) + if execErr != nil { + backoffUntil := time.Now().Add(c.cfg.FailedRetryBackoff) + reason := infraerrors.Reason(execErr) + if reason == "" { + reason = "EXECUTION_FAILED" + } + recordIdempotencyRetryBackoff(opts.Route, opts.Scope, nil) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->failed_retryable", false, map[string]string{ + "reason": reason, + }) + if markErr := c.repo.MarkFailedRetryable(ctx, record.ID, reason, backoffUntil, expiresAt); markErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "mark_failed_retryable_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{ + "operation": "mark_failed_retryable", + }) + } + return nil, execErr + } + + storedBody, marshalErr := c.marshalStoredResponse(data) + if marshalErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "marshal_response_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{ + "operation": "marshal_response", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(marshalErr) + } + if markErr := c.repo.MarkSucceeded(ctx, record.ID, 200, storedBody, expiresAt); markErr != nil { + 
RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "mark_succeeded_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{ + "operation": "mark_succeeded", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(markErr) + } + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->succeeded", false, nil) + + return &IdempotencyExecuteResult{Data: data}, nil +} + +func (c *IdempotencyCoordinator) conflictWithRetryAfter(base *infraerrors.ApplicationError, lockedUntil *time.Time, now time.Time) error { + if lockedUntil == nil { + return base + } + sec := int(lockedUntil.Sub(now).Seconds()) + if sec <= 0 { + sec = 1 + } + return base.WithMetadata(map[string]string{"retry_after": strconv.Itoa(sec)}) +} + +func (c *IdempotencyCoordinator) marshalStoredResponse(data any) (string, error) { + raw, err := json.Marshal(data) + if err != nil { + return "", err + } + redacted := logredact.RedactText(string(raw)) + if c.cfg.MaxStoredResponseLen > 0 && len(redacted) > c.cfg.MaxStoredResponseLen { + redacted = redacted[:c.cfg.MaxStoredResponseLen] + "...(truncated)" + } + return redacted, nil +} + +func (c *IdempotencyCoordinator) decodeStoredResponse(stored *string) (any, error) { + if stored == nil || strings.TrimSpace(*stored) == "" { + return map[string]any{}, nil + } + var out any + if err := json.Unmarshal([]byte(*stored), &out); err != nil { + return nil, fmt.Errorf("decode stored response: %w", err) + } + return out, nil +} diff --git a/backend/internal/service/idempotency_cleanup_service.go b/backend/internal/service/idempotency_cleanup_service.go new file mode 100644 index 00000000..aaf6949a --- /dev/null +++ b/backend/internal/service/idempotency_cleanup_service.go @@ -0,0 +1,91 @@ +package service + +import ( + "context" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +// IdempotencyCleanupService 
定期清理已过期的幂等记录,避免表无限增长。 +type IdempotencyCleanupService struct { + repo IdempotencyRepository + interval time.Duration + batch int + + startOnce sync.Once + stopOnce sync.Once + stopCh chan struct{} +} + +func NewIdempotencyCleanupService(repo IdempotencyRepository, cfg *config.Config) *IdempotencyCleanupService { + interval := 60 * time.Second + batch := 500 + if cfg != nil { + if cfg.Idempotency.CleanupIntervalSeconds > 0 { + interval = time.Duration(cfg.Idempotency.CleanupIntervalSeconds) * time.Second + } + if cfg.Idempotency.CleanupBatchSize > 0 { + batch = cfg.Idempotency.CleanupBatchSize + } + } + return &IdempotencyCleanupService{ + repo: repo, + interval: interval, + batch: batch, + stopCh: make(chan struct{}), + } +} + +func (s *IdempotencyCleanupService) Start() { + if s == nil || s.repo == nil { + return + } + s.startOnce.Do(func() { + logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] started interval=%s batch=%d", s.interval, s.batch) + go s.runLoop() + }) +} + +func (s *IdempotencyCleanupService) Stop() { + if s == nil { + return + } + s.stopOnce.Do(func() { + close(s.stopCh) + logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] stopped") + }) +} + +func (s *IdempotencyCleanupService) runLoop() { + ticker := time.NewTicker(s.interval) + defer ticker.Stop() + + // 启动后先清理一轮,防止重启后积压。 + s.cleanupOnce() + + for { + select { + case <-ticker.C: + s.cleanupOnce() + case <-s.stopCh: + return + } + } +} + +func (s *IdempotencyCleanupService) cleanupOnce() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + deleted, err := s.repo.DeleteExpired(ctx, time.Now(), s.batch) + if err != nil { + logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] cleanup failed err=%v", err) + return + } + if deleted > 0 { + logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] cleaned expired records count=%d", deleted) + } +} diff --git 
a/backend/internal/service/idempotency_cleanup_service_test.go b/backend/internal/service/idempotency_cleanup_service_test.go new file mode 100644 index 00000000..556ff364 --- /dev/null +++ b/backend/internal/service/idempotency_cleanup_service_test.go @@ -0,0 +1,69 @@ +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type idempotencyCleanupRepoStub struct { + deleteCalls int + lastLimit int + deleteErr error +} + +func (r *idempotencyCleanupRepoStub) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) { + return false, nil +} +func (r *idempotencyCleanupRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) { + return nil, nil +} +func (r *idempotencyCleanupRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + return false, nil +} +func (r *idempotencyCleanupRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return false, nil +} +func (r *idempotencyCleanupRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error { + return nil +} +func (r *idempotencyCleanupRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error { + return nil +} +func (r *idempotencyCleanupRepoStub) DeleteExpired(_ context.Context, _ time.Time, limit int) (int64, error) { + r.deleteCalls++ + r.lastLimit = limit + if r.deleteErr != nil { + return 0, r.deleteErr + } + return 1, nil +} + +func TestNewIdempotencyCleanupService_UsesConfig(t *testing.T) { + repo := &idempotencyCleanupRepoStub{} + cfg := &config.Config{ + Idempotency: config.IdempotencyConfig{ + CleanupIntervalSeconds: 7, + CleanupBatchSize: 321, + }, + } + svc := NewIdempotencyCleanupService(repo, cfg) + require.Equal(t, 7*time.Second, svc.interval) + require.Equal(t, 321, svc.batch) +} + +func 
TestIdempotencyCleanupService_CleanupOnce(t *testing.T) { + repo := &idempotencyCleanupRepoStub{} + svc := NewIdempotencyCleanupService(repo, &config.Config{ + Idempotency: config.IdempotencyConfig{ + CleanupBatchSize: 99, + }, + }) + + svc.cleanupOnce() + require.Equal(t, 1, repo.deleteCalls) + require.Equal(t, 99, repo.lastLimit) +} diff --git a/backend/internal/service/idempotency_observability.go b/backend/internal/service/idempotency_observability.go new file mode 100644 index 00000000..f1bf2df2 --- /dev/null +++ b/backend/internal/service/idempotency_observability.go @@ -0,0 +1,171 @@ +package service + +import ( + "sort" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +// IdempotencyMetricsSnapshot 提供幂等核心指标快照(进程内累计)。 +type IdempotencyMetricsSnapshot struct { + ClaimTotal uint64 `json:"claim_total"` + ReplayTotal uint64 `json:"replay_total"` + ConflictTotal uint64 `json:"conflict_total"` + RetryBackoffTotal uint64 `json:"retry_backoff_total"` + ProcessingDurationCount uint64 `json:"processing_duration_count"` + ProcessingDurationTotalMs float64 `json:"processing_duration_total_ms"` + StoreUnavailableTotal uint64 `json:"store_unavailable_total"` +} + +type idempotencyMetrics struct { + claimTotal atomic.Uint64 + replayTotal atomic.Uint64 + conflictTotal atomic.Uint64 + retryBackoffTotal atomic.Uint64 + processingDurationCount atomic.Uint64 + processingDurationMicros atomic.Uint64 + storeUnavailableTotal atomic.Uint64 +} + +var defaultIdempotencyMetrics idempotencyMetrics + +// GetIdempotencyMetricsSnapshot 返回当前幂等指标快照。 +func GetIdempotencyMetricsSnapshot() IdempotencyMetricsSnapshot { + totalMicros := defaultIdempotencyMetrics.processingDurationMicros.Load() + return IdempotencyMetricsSnapshot{ + ClaimTotal: defaultIdempotencyMetrics.claimTotal.Load(), + ReplayTotal: defaultIdempotencyMetrics.replayTotal.Load(), + ConflictTotal: defaultIdempotencyMetrics.conflictTotal.Load(), + RetryBackoffTotal: 
defaultIdempotencyMetrics.retryBackoffTotal.Load(), + ProcessingDurationCount: defaultIdempotencyMetrics.processingDurationCount.Load(), + ProcessingDurationTotalMs: float64(totalMicros) / 1000.0, + StoreUnavailableTotal: defaultIdempotencyMetrics.storeUnavailableTotal.Load(), + } +} + +func recordIdempotencyClaim(endpoint, scope string, attrs map[string]string) { + defaultIdempotencyMetrics.claimTotal.Add(1) + logIdempotencyMetric("idempotency_claim_total", endpoint, scope, "1", attrs) +} + +func recordIdempotencyReplay(endpoint, scope string, attrs map[string]string) { + defaultIdempotencyMetrics.replayTotal.Add(1) + logIdempotencyMetric("idempotency_replay_total", endpoint, scope, "1", attrs) +} + +func recordIdempotencyConflict(endpoint, scope string, attrs map[string]string) { + defaultIdempotencyMetrics.conflictTotal.Add(1) + logIdempotencyMetric("idempotency_conflict_total", endpoint, scope, "1", attrs) +} + +func recordIdempotencyRetryBackoff(endpoint, scope string, attrs map[string]string) { + defaultIdempotencyMetrics.retryBackoffTotal.Add(1) + logIdempotencyMetric("idempotency_retry_backoff_total", endpoint, scope, "1", attrs) +} + +func recordIdempotencyProcessingDuration(endpoint, scope string, duration time.Duration, attrs map[string]string) { + if duration < 0 { + duration = 0 + } + defaultIdempotencyMetrics.processingDurationCount.Add(1) + defaultIdempotencyMetrics.processingDurationMicros.Add(uint64(duration.Microseconds())) + logIdempotencyMetric("idempotency_processing_duration_ms", endpoint, scope, strconv.FormatFloat(duration.Seconds()*1000, 'f', 3, 64), attrs) +} + +// RecordIdempotencyStoreUnavailable 记录幂等存储不可用事件(用于降级路径观测)。 +func RecordIdempotencyStoreUnavailable(endpoint, scope, strategy string) { + defaultIdempotencyMetrics.storeUnavailableTotal.Add(1) + attrs := map[string]string{} + if strategy != "" { + attrs["strategy"] = strategy + } + logIdempotencyMetric("idempotency_store_unavailable_total", endpoint, scope, "1", attrs) +} + +func 
logIdempotencyAudit(endpoint, scope, keyHash, stateTransition string, replayed bool, attrs map[string]string) { + var b strings.Builder + builderWriteString(&b, "[IdempotencyAudit]") + builderWriteString(&b, " endpoint=") + builderWriteString(&b, safeAuditField(endpoint)) + builderWriteString(&b, " scope=") + builderWriteString(&b, safeAuditField(scope)) + builderWriteString(&b, " key_hash=") + builderWriteString(&b, safeAuditField(keyHash)) + builderWriteString(&b, " state_transition=") + builderWriteString(&b, safeAuditField(stateTransition)) + builderWriteString(&b, " replayed=") + builderWriteString(&b, strconv.FormatBool(replayed)) + if len(attrs) > 0 { + appendSortedAttrs(&b, attrs) + } + logger.LegacyPrintf("service.idempotency", "%s", b.String()) +} + +func logIdempotencyMetric(name, endpoint, scope, value string, attrs map[string]string) { + var b strings.Builder + builderWriteString(&b, "[IdempotencyMetric]") + builderWriteString(&b, " name=") + builderWriteString(&b, safeAuditField(name)) + builderWriteString(&b, " endpoint=") + builderWriteString(&b, safeAuditField(endpoint)) + builderWriteString(&b, " scope=") + builderWriteString(&b, safeAuditField(scope)) + builderWriteString(&b, " value=") + builderWriteString(&b, safeAuditField(value)) + if len(attrs) > 0 { + appendSortedAttrs(&b, attrs) + } + logger.LegacyPrintf("service.idempotency", "%s", b.String()) +} + +func appendSortedAttrs(builder *strings.Builder, attrs map[string]string) { + if len(attrs) == 0 { + return + } + keys := make([]string, 0, len(attrs)) + for k := range attrs { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + builderWriteByte(builder, ' ') + builderWriteString(builder, k) + builderWriteByte(builder, '=') + builderWriteString(builder, safeAuditField(attrs[k])) + } +} + +func safeAuditField(v string) string { + value := strings.TrimSpace(v) + if value == "" { + return "-" + } + // 日志按 key=value 输出,替换空白避免解析歧义。 + value = strings.ReplaceAll(value, 
"\n", "_") + value = strings.ReplaceAll(value, "\r", "_") + value = strings.ReplaceAll(value, "\t", "_") + value = strings.ReplaceAll(value, " ", "_") + return value +} + +func resetIdempotencyMetricsForTest() { + defaultIdempotencyMetrics.claimTotal.Store(0) + defaultIdempotencyMetrics.replayTotal.Store(0) + defaultIdempotencyMetrics.conflictTotal.Store(0) + defaultIdempotencyMetrics.retryBackoffTotal.Store(0) + defaultIdempotencyMetrics.processingDurationCount.Store(0) + defaultIdempotencyMetrics.processingDurationMicros.Store(0) + defaultIdempotencyMetrics.storeUnavailableTotal.Store(0) +} + +func builderWriteString(builder *strings.Builder, value string) { + _, _ = builder.WriteString(value) +} + +func builderWriteByte(builder *strings.Builder, value byte) { + _ = builder.WriteByte(value) +} diff --git a/backend/internal/service/idempotency_test.go b/backend/internal/service/idempotency_test.go new file mode 100644 index 00000000..6ff75d1c --- /dev/null +++ b/backend/internal/service/idempotency_test.go @@ -0,0 +1,805 @@ +package service + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +type inMemoryIdempotencyRepo struct { + mu sync.Mutex + nextID int64 + data map[string]*IdempotencyRecord +} + +func newInMemoryIdempotencyRepo() *inMemoryIdempotencyRepo { + return &inMemoryIdempotencyRepo{ + nextID: 1, + data: make(map[string]*IdempotencyRecord), + } +} + +func (r *inMemoryIdempotencyRepo) key(scope, hash string) string { + return scope + "|" + hash +} + +func cloneRecord(in *IdempotencyRecord) *IdempotencyRecord { + if in == nil { + return nil + } + out := *in + if in.ResponseStatus != nil { + v := *in.ResponseStatus + out.ResponseStatus = &v + } + if in.ResponseBody != nil { + v := *in.ResponseBody + out.ResponseBody = &v + } + if in.ErrorReason != nil { + v := *in.ErrorReason + out.ErrorReason = &v + } + if 
in.LockedUntil != nil { + v := *in.LockedUntil + out.LockedUntil = &v + } + return &out +} + +func (r *inMemoryIdempotencyRepo) CreateProcessing(_ context.Context, record *IdempotencyRecord) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + k := r.key(record.Scope, record.IdempotencyKeyHash) + if _, ok := r.data[k]; ok { + return false, nil + } + rec := cloneRecord(record) + rec.ID = r.nextID + rec.CreatedAt = time.Now() + rec.UpdatedAt = rec.CreatedAt + r.nextID++ + r.data[k] = rec + record.ID = rec.ID + record.CreatedAt = rec.CreatedAt + record.UpdatedAt = rec.UpdatedAt + return true, nil +} + +func (r *inMemoryIdempotencyRepo) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*IdempotencyRecord, error) { + r.mu.Lock() + defer r.mu.Unlock() + return cloneRecord(r.data[r.key(scope, keyHash)]), nil +} + +func (r *inMemoryIdempotencyRepo) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + if rec.Status != fromStatus { + return false, nil + } + if rec.LockedUntil != nil && rec.LockedUntil.After(now) { + return false, nil + } + rec.Status = IdempotencyStatusProcessing + rec.LockedUntil = &newLockedUntil + rec.ExpiresAt = newExpiresAt + rec.ErrorReason = nil + rec.UpdatedAt = time.Now() + return true, nil + } + return false, nil +} + +func (r *inMemoryIdempotencyRepo) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + + for _, rec := range r.data { + if rec.ID != id { + continue + } + if rec.Status != IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint { + return false, nil + } + rec.LockedUntil = &newLockedUntil + rec.ExpiresAt = newExpiresAt + rec.UpdatedAt = time.Now() + return true, nil + } + return false, nil +} + +func (r 
*inMemoryIdempotencyRepo) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + rec.Status = IdempotencyStatusSucceeded + rec.LockedUntil = nil + rec.ExpiresAt = expiresAt + rec.UpdatedAt = time.Now() + rec.ErrorReason = nil + rec.ResponseStatus = &responseStatus + rec.ResponseBody = &responseBody + return nil + } + return errors.New("record not found") +} + +func (r *inMemoryIdempotencyRepo) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + rec.Status = IdempotencyStatusFailedRetryable + rec.LockedUntil = &lockedUntil + rec.ExpiresAt = expiresAt + rec.UpdatedAt = time.Now() + rec.ErrorReason = &errorReason + return nil + } + return errors.New("record not found") +} + +func (r *inMemoryIdempotencyRepo) DeleteExpired(_ context.Context, now time.Time, _ int) (int64, error) { + r.mu.Lock() + defer r.mu.Unlock() + var deleted int64 + for k, rec := range r.data { + if !rec.ExpiresAt.After(now) { + delete(r.data, k) + deleted++ + } + } + return deleted, nil +} + +func TestIdempotencyCoordinator_RequireKey(t *testing.T) { + resetIdempotencyMetricsForTest() + repo := newInMemoryIdempotencyRepo() + cfg := DefaultIdempotencyConfig() + cfg.ObserveOnly = false + coordinator := NewIdempotencyCoordinator(repo, cfg) + + _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "test.scope", + Method: "POST", + Route: "/test", + ActorScope: "admin:1", + RequireKey: true, + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(err), infraerrors.Code(ErrIdempotencyKeyRequired)) +} + +func 
TestIdempotencyCoordinator_ReplaySucceededResult(t *testing.T) { + resetIdempotencyMetricsForTest() + repo := newInMemoryIdempotencyRepo() + cfg := DefaultIdempotencyConfig() + coordinator := NewIdempotencyCoordinator(repo, cfg) + + execCount := 0 + exec := func(ctx context.Context) (any, error) { + execCount++ + return map[string]any{"count": execCount}, nil + } + + opts := IdempotencyExecuteOptions{ + Scope: "test.scope", + Method: "POST", + Route: "/test", + ActorScope: "user:1", + RequireKey: true, + IdempotencyKey: "case-1", + Payload: map[string]any{"a": 1}, + } + + first, err := coordinator.Execute(context.Background(), opts, exec) + require.NoError(t, err) + require.False(t, first.Replayed) + + second, err := coordinator.Execute(context.Background(), opts, exec) + require.NoError(t, err) + require.True(t, second.Replayed) + require.Equal(t, 1, execCount, "second request should replay without executing business logic") + + metrics := GetIdempotencyMetricsSnapshot() + require.Equal(t, uint64(1), metrics.ClaimTotal) + require.Equal(t, uint64(1), metrics.ReplayTotal) +} + +func TestIdempotencyCoordinator_ReclaimExpiredSucceededRecord(t *testing.T) { + resetIdempotencyMetricsForTest() + repo := newInMemoryIdempotencyRepo() + coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig()) + + opts := IdempotencyExecuteOptions{ + Scope: "test.scope.expired", + Method: "POST", + Route: "/test/expired", + ActorScope: "user:99", + RequireKey: true, + IdempotencyKey: "expired-case", + Payload: map[string]any{"k": "v"}, + } + + execCount := 0 + exec := func(ctx context.Context) (any, error) { + execCount++ + return map[string]any{"count": execCount}, nil + } + + first, err := coordinator.Execute(context.Background(), opts, exec) + require.NoError(t, err) + require.NotNil(t, first) + require.False(t, first.Replayed) + require.Equal(t, 1, execCount) + + keyHash := HashIdempotencyKey(opts.IdempotencyKey) + repo.mu.Lock() + existing := 
repo.data[repo.key(opts.Scope, keyHash)] + require.NotNil(t, existing) + existing.ExpiresAt = time.Now().Add(-time.Second) + repo.mu.Unlock() + + second, err := coordinator.Execute(context.Background(), opts, exec) + require.NoError(t, err) + require.NotNil(t, second) + require.False(t, second.Replayed, "expired record should be reclaimed and execute business logic again") + require.Equal(t, 2, execCount) + + third, err := coordinator.Execute(context.Background(), opts, exec) + require.NoError(t, err) + require.NotNil(t, third) + require.True(t, third.Replayed) + payload, ok := third.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, float64(2), payload["count"]) + + metrics := GetIdempotencyMetricsSnapshot() + require.GreaterOrEqual(t, metrics.ClaimTotal, uint64(2)) + require.GreaterOrEqual(t, metrics.ReplayTotal, uint64(1)) +} + +func TestIdempotencyCoordinator_SameKeyDifferentPayloadConflict(t *testing.T) { + resetIdempotencyMetricsForTest() + repo := newInMemoryIdempotencyRepo() + cfg := DefaultIdempotencyConfig() + coordinator := NewIdempotencyCoordinator(repo, cfg) + + _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "test.scope", + Method: "POST", + Route: "/test", + ActorScope: "user:1", + RequireKey: true, + IdempotencyKey: "case-2", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.NoError(t, err) + + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "test.scope", + Method: "POST", + Route: "/test", + ActorScope: "user:1", + RequireKey: true, + IdempotencyKey: "case-2", + Payload: map[string]any{"a": 2}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(err), infraerrors.Code(ErrIdempotencyKeyConflict)) + + metrics := GetIdempotencyMetricsSnapshot() + require.Equal(t, uint64(1), 
metrics.ConflictTotal) +} + +func TestIdempotencyCoordinator_BackoffAfterRetryableFailure(t *testing.T) { + resetIdempotencyMetricsForTest() + repo := newInMemoryIdempotencyRepo() + cfg := DefaultIdempotencyConfig() + cfg.FailedRetryBackoff = 2 * time.Second + coordinator := NewIdempotencyCoordinator(repo, cfg) + + opts := IdempotencyExecuteOptions{ + Scope: "test.scope", + Method: "POST", + Route: "/test", + ActorScope: "user:1", + RequireKey: true, + IdempotencyKey: "case-3", + Payload: map[string]any{"a": 1}, + } + + _, err := coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) { + return nil, infraerrors.InternalServer("UPSTREAM_ERROR", "upstream error") + }) + require.Error(t, err) + + _, err = coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(err), infraerrors.Code(ErrIdempotencyRetryBackoff)) + require.Greater(t, RetryAfterSecondsFromError(err), 0) + + metrics := GetIdempotencyMetricsSnapshot() + require.GreaterOrEqual(t, metrics.RetryBackoffTotal, uint64(2)) + require.GreaterOrEqual(t, metrics.ConflictTotal, uint64(1)) + require.GreaterOrEqual(t, metrics.ProcessingDurationCount, uint64(1)) +} + +func TestIdempotencyCoordinator_ConcurrentSameKeySingleSideEffect(t *testing.T) { + resetIdempotencyMetricsForTest() + repo := newInMemoryIdempotencyRepo() + cfg := DefaultIdempotencyConfig() + cfg.ProcessingTimeout = 2 * time.Second + coordinator := NewIdempotencyCoordinator(repo, cfg) + + opts := IdempotencyExecuteOptions{ + Scope: "test.scope.concurrent", + Method: "POST", + Route: "/test/concurrent", + ActorScope: "user:7", + RequireKey: true, + IdempotencyKey: "concurrent-case", + Payload: map[string]any{"v": 1}, + } + + var execCount int32 + var wg sync.WaitGroup + for i := 0; i < 8; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, _ = coordinator.Execute(context.Background(), 
opts, func(ctx context.Context) (any, error) { + atomic.AddInt32(&execCount, 1) + time.Sleep(80 * time.Millisecond) + return map[string]any{"ok": true}, nil + }) + }() + } + wg.Wait() + + replayed, err := coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) { + atomic.AddInt32(&execCount, 1) + return map[string]any{"ok": true}, nil + }) + require.NoError(t, err) + require.True(t, replayed.Replayed) + require.Equal(t, int32(1), atomic.LoadInt32(&execCount), "concurrent same-key requests should execute business side-effect once") + + metrics := GetIdempotencyMetricsSnapshot() + require.Equal(t, uint64(1), metrics.ClaimTotal) + require.Equal(t, uint64(1), metrics.ReplayTotal) + require.GreaterOrEqual(t, metrics.ConflictTotal, uint64(1)) +} + +type failingIdempotencyRepo struct{} + +func (failingIdempotencyRepo) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) { + return false, errors.New("store unavailable") +} +func (failingIdempotencyRepo) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) { + return nil, errors.New("store unavailable") +} +func (failingIdempotencyRepo) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + return false, errors.New("store unavailable") +} +func (failingIdempotencyRepo) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return false, errors.New("store unavailable") +} +func (failingIdempotencyRepo) MarkSucceeded(context.Context, int64, int, string, time.Time) error { + return errors.New("store unavailable") +} +func (failingIdempotencyRepo) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error { + return errors.New("store unavailable") +} +func (failingIdempotencyRepo) DeleteExpired(context.Context, time.Time, int) (int64, error) { + return 0, errors.New("store unavailable") +} + +func TestIdempotencyCoordinator_StoreUnavailableMetrics(t 
*testing.T) { + resetIdempotencyMetricsForTest() + coordinator := NewIdempotencyCoordinator(failingIdempotencyRepo{}, DefaultIdempotencyConfig()) + + _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "test.scope.unavailable", + Method: "POST", + Route: "/test/unavailable", + ActorScope: "admin:1", + RequireKey: true, + IdempotencyKey: "case-unavailable", + Payload: map[string]any{"v": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + require.GreaterOrEqual(t, GetIdempotencyMetricsSnapshot().StoreUnavailableTotal, uint64(1)) +} + +func TestDefaultIdempotencyCoordinatorAndTTLs(t *testing.T) { + SetDefaultIdempotencyCoordinator(nil) + require.Nil(t, DefaultIdempotencyCoordinator()) + require.Equal(t, DefaultIdempotencyConfig().DefaultTTL, DefaultWriteIdempotencyTTL()) + require.Equal(t, DefaultIdempotencyConfig().SystemOperationTTL, DefaultSystemOperationIdempotencyTTL()) + + coordinator := NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), IdempotencyConfig{ + DefaultTTL: 2 * time.Hour, + SystemOperationTTL: 15 * time.Minute, + ProcessingTimeout: 10 * time.Second, + FailedRetryBackoff: 3 * time.Second, + ObserveOnly: false, + }) + SetDefaultIdempotencyCoordinator(coordinator) + t.Cleanup(func() { + SetDefaultIdempotencyCoordinator(nil) + }) + + require.Same(t, coordinator, DefaultIdempotencyCoordinator()) + require.Equal(t, 2*time.Hour, DefaultWriteIdempotencyTTL()) + require.Equal(t, 15*time.Minute, DefaultSystemOperationIdempotencyTTL()) +} + +func TestNormalizeIdempotencyKeyAndFingerprint(t *testing.T) { + key, err := NormalizeIdempotencyKey(" abc-123 ") + require.NoError(t, err) + require.Equal(t, "abc-123", key) + + key, err = NormalizeIdempotencyKey("") + require.NoError(t, err) + require.Equal(t, "", key) + + _, err = NormalizeIdempotencyKey(string(make([]byte, 
129))) + require.Error(t, err) + + _, err = NormalizeIdempotencyKey("bad\nkey") + require.Error(t, err) + + fp1, err := BuildIdempotencyFingerprint("", "", "", map[string]any{"a": 1}) + require.NoError(t, err) + require.NotEmpty(t, fp1) + fp2, err := BuildIdempotencyFingerprint("POST", "/", "anonymous", map[string]any{"a": 1}) + require.NoError(t, err) + require.Equal(t, fp1, fp2) + + _, err = BuildIdempotencyFingerprint("POST", "/x", "u:1", map[string]any{"bad": make(chan int)}) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyInvalidPayload), infraerrors.Code(err)) +} + +func TestRetryAfterSecondsFromErrorBranches(t *testing.T) { + require.Equal(t, 0, RetryAfterSecondsFromError(nil)) + require.Equal(t, 0, RetryAfterSecondsFromError(errors.New("plain"))) + + err := ErrIdempotencyInProgress.WithMetadata(map[string]string{"retry_after": "12"}) + require.Equal(t, 12, RetryAfterSecondsFromError(err)) + + err = ErrIdempotencyInProgress.WithMetadata(map[string]string{"retry_after": "bad"}) + require.Equal(t, 0, RetryAfterSecondsFromError(err)) +} + +func TestIdempotencyCoordinator_ExecuteNilExecutorAndNoKeyPassThrough(t *testing.T) { + repo := newInMemoryIdempotencyRepo() + coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig()) + + _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + IdempotencyKey: "k", + Payload: map[string]any{"a": 1}, + }, nil) + require.Error(t, err) + require.Equal(t, "IDEMPOTENCY_EXECUTOR_NIL", infraerrors.Reason(err)) + + called := 0 + result, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + RequireKey: true, + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + called++ + return map[string]any{"ok": true}, nil + }) + require.NoError(t, err) + require.Equal(t, 1, called) + require.NotNil(t, result) + require.False(t, result.Replayed) +} + +type noIDOwnerRepo struct{} + +func 
(noIDOwnerRepo) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) { + return true, nil +} +func (noIDOwnerRepo) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) { + return nil, nil +} +func (noIDOwnerRepo) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + return false, nil +} +func (noIDOwnerRepo) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return false, nil +} +func (noIDOwnerRepo) MarkSucceeded(context.Context, int64, int, string, time.Time) error { return nil } +func (noIDOwnerRepo) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error { + return nil +} +func (noIDOwnerRepo) DeleteExpired(context.Context, time.Time, int) (int64, error) { return 0, nil } + +func TestIdempotencyCoordinator_RepoNilScopeRequiredAndRecordIDMissing(t *testing.T) { + cfg := DefaultIdempotencyConfig() + coordinator := NewIdempotencyCoordinator(nil, cfg) + + _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + IdempotencyKey: "k", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + + coordinator = NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), cfg) + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + IdempotencyKey: "k2", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, "IDEMPOTENCY_SCOPE_REQUIRED", infraerrors.Reason(err)) + + coordinator = NewIdempotencyCoordinator(noIDOwnerRepo{}, cfg) + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope-no-id", + IdempotencyKey: "k3", + Payload: 
map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) +} + +type conflictBranchRepo struct { + existing *IdempotencyRecord + tryReclaimErr error + tryReclaimOK bool +} + +func (r *conflictBranchRepo) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) { + return false, nil +} +func (r *conflictBranchRepo) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) { + return cloneRecord(r.existing), nil +} +func (r *conflictBranchRepo) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + if r.tryReclaimErr != nil { + return false, r.tryReclaimErr + } + return r.tryReclaimOK, nil +} +func (r *conflictBranchRepo) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return false, nil +} +func (r *conflictBranchRepo) MarkSucceeded(context.Context, int64, int, string, time.Time) error { + return nil +} +func (r *conflictBranchRepo) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error { + return nil +} +func (r *conflictBranchRepo) DeleteExpired(context.Context, time.Time, int) (int64, error) { + return 0, nil +} + +func TestIdempotencyCoordinator_ConflictBranchesAndDecodeError(t *testing.T) { + now := time.Now() + fp, err := BuildIdempotencyFingerprint("POST", "/x", "u:1", map[string]any{"a": 1}) + require.NoError(t, err) + badBody := "{bad-json" + repo := &conflictBranchRepo{ + existing: &IdempotencyRecord{ + ID: 1, + Scope: "scope", + IdempotencyKeyHash: HashIdempotencyKey("k"), + RequestFingerprint: fp, + Status: IdempotencyStatusSucceeded, + ResponseBody: &badBody, + ExpiresAt: now.Add(time.Hour), + }, + } + coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig()) + _, err = coordinator.Execute(context.Background(), 
IdempotencyExecuteOptions{ + Scope: "scope", + IdempotencyKey: "k", + Method: "POST", + Route: "/x", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + + repo.existing = &IdempotencyRecord{ + ID: 2, + Scope: "scope", + IdempotencyKeyHash: HashIdempotencyKey("k"), + RequestFingerprint: fp, + Status: "unknown", + ExpiresAt: now.Add(time.Hour), + } + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + IdempotencyKey: "k", + Method: "POST", + Route: "/x", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyKeyConflict), infraerrors.Code(err)) + + repo.existing = &IdempotencyRecord{ + ID: 3, + Scope: "scope", + IdempotencyKeyHash: HashIdempotencyKey("k"), + RequestFingerprint: fp, + Status: IdempotencyStatusFailedRetryable, + LockedUntil: ptrTime(now.Add(-time.Second)), + ExpiresAt: now.Add(time.Hour), + } + repo.tryReclaimErr = errors.New("reclaim down") + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + IdempotencyKey: "k", + Method: "POST", + Route: "/x", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + + repo.tryReclaimErr = nil + repo.tryReclaimOK = false + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + IdempotencyKey: "k", + Method: "POST", + Route: "/x", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx 
context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyInProgress), infraerrors.Code(err)) +} + +type markBehaviorRepo struct { + inMemoryIdempotencyRepo + failMarkSucceeded bool + failMarkFailed bool +} + +func (r *markBehaviorRepo) MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error { + if r.failMarkSucceeded { + return errors.New("mark succeeded failed") + } + return r.inMemoryIdempotencyRepo.MarkSucceeded(ctx, id, responseStatus, responseBody, expiresAt) +} + +func (r *markBehaviorRepo) MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error { + if r.failMarkFailed { + return errors.New("mark failed retryable failed") + } + return r.inMemoryIdempotencyRepo.MarkFailedRetryable(ctx, id, errorReason, lockedUntil, expiresAt) +} + +func TestIdempotencyCoordinator_MarkAndMarshalBranches(t *testing.T) { + repo := &markBehaviorRepo{inMemoryIdempotencyRepo: *newInMemoryIdempotencyRepo()} + coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig()) + + repo.failMarkSucceeded = true + _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope-success", + IdempotencyKey: "k1", + Method: "POST", + Route: "/ok", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + + repo.failMarkSucceeded = false + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope-marshal", + IdempotencyKey: "k2", + Method: "POST", + Route: "/bad", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"bad": make(chan int)}, nil 
+ }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + + repo.failMarkFailed = true + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope-fail", + IdempotencyKey: "k3", + Method: "POST", + Route: "/fail", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return nil, errors.New("plain failure") + }) + require.Error(t, err) + require.Equal(t, "plain failure", err.Error()) +} + +func TestIdempotencyCoordinator_HelperBranches(t *testing.T) { + c := NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), IdempotencyConfig{ + DefaultTTL: time.Hour, + SystemOperationTTL: time.Hour, + ProcessingTimeout: time.Second, + FailedRetryBackoff: time.Second, + MaxStoredResponseLen: 12, + ObserveOnly: false, + }) + + // conflictWithRetryAfter without locked_until should return base error. + base := ErrIdempotencyInProgress + err := c.conflictWithRetryAfter(base, nil, time.Now()) + require.Equal(t, infraerrors.Code(base), infraerrors.Code(err)) + + // marshalStoredResponse should truncate. + body, err := c.marshalStoredResponse(map[string]any{"long": "abcdefghijklmnopqrstuvwxyz"}) + require.NoError(t, err) + require.Contains(t, body, "...(truncated)") + + // decodeStoredResponse empty and invalid json. 
+ out, err := c.decodeStoredResponse(nil) + require.NoError(t, err) + _, ok := out.(map[string]any) + require.True(t, ok) + + invalid := "{invalid" + _, err = c.decodeStoredResponse(&invalid) + require.Error(t, err) +} diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go index 261da0ef..f3130c91 100644 --- a/backend/internal/service/identity_service.go +++ b/backend/internal/service/identity_service.go @@ -7,13 +7,14 @@ import ( "encoding/hex" "encoding/json" "fmt" - "log" "log/slog" "net/http" "regexp" "strconv" "strings" "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) // 预编译正则表达式(避免每次调用重新编译) @@ -45,6 +46,7 @@ type Fingerprint struct { StainlessArch string StainlessRuntime string StainlessRuntimeVersion string + UpdatedAt int64 `json:",omitempty"` // Unix timestamp,用于判断是否需要续期TTL } // IdentityCache defines cache operations for identity service @@ -77,14 +79,26 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID // 尝试从缓存获取指纹 cached, err := s.cache.GetFingerprint(ctx, accountID) if err == nil && cached != nil { + needWrite := false + // 检查客户端的user-agent是否是更新版本 clientUA := headers.Get("User-Agent") if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) { - // 更新user-agent - cached.UserAgent = clientUA - // 保存更新后的指纹 - _ = s.cache.SetFingerprint(ctx, accountID, cached) - log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA) + // 版本升级:merge 语义 — 仅更新请求中实际携带的字段,保留缓存值 + // 避免缺失的头被硬编码默认值覆盖(如新 CLI 版本 + 旧 SDK 默认值的不一致) + mergeHeadersIntoFingerprint(cached, headers) + needWrite = true + logger.LegacyPrintf("service.identity", "Updated fingerprint for account %d: %s (merge update)", accountID, clientUA) + } else if time.Since(time.Unix(cached.UpdatedAt, 0)) > 24*time.Hour { + // 距上次写入超过24小时,续期TTL + needWrite = true + } + + if needWrite { + cached.UpdatedAt = time.Now().Unix() + if err := s.cache.SetFingerprint(ctx, accountID, cached); err != 
nil { + logger.LegacyPrintf("service.identity", "Warning: failed to refresh fingerprint for account %d: %v", accountID, err) + } } return cached, nil } @@ -94,13 +108,14 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID // 生成随机ClientID fp.ClientID = generateClientID() + fp.UpdatedAt = time.Now().Unix() - // 保存到缓存(永不过期) + // 保存到缓存(7天TTL,每24小时自动续期) if err := s.cache.SetFingerprint(ctx, accountID, fp); err != nil { - log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err) + logger.LegacyPrintf("service.identity", "Warning: failed to cache fingerprint for account %d: %v", accountID, err) } - log.Printf("Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID) + logger.LegacyPrintf("service.identity", "Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID) return fp, nil } @@ -126,6 +141,31 @@ func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *Fin return fp } +// mergeHeadersIntoFingerprint 将请求头中实际存在的字段合并到现有指纹中(用于版本升级场景) +// 关键语义:请求中有的字段 → 用新值覆盖;缺失的头 → 保留缓存中的已有值 +// 与 createFingerprintFromHeaders 的区别:后者用于首次创建,缺失头回退到 defaultFingerprint; +// 本函数用于升级更新,缺失头保留缓存值,避免将已知的真实值退化为硬编码默认值 +func mergeHeadersIntoFingerprint(fp *Fingerprint, headers http.Header) { + // User-Agent:版本升级的触发条件,一定存在 + if ua := headers.Get("User-Agent"); ua != "" { + fp.UserAgent = ua + } + // X-Stainless-* 头:仅在请求中实际携带时才更新,否则保留缓存值 + mergeHeader(headers, "X-Stainless-Lang", &fp.StainlessLang) + mergeHeader(headers, "X-Stainless-Package-Version", &fp.StainlessPackageVersion) + mergeHeader(headers, "X-Stainless-OS", &fp.StainlessOS) + mergeHeader(headers, "X-Stainless-Arch", &fp.StainlessArch) + mergeHeader(headers, "X-Stainless-Runtime", &fp.StainlessRuntime) + mergeHeader(headers, "X-Stainless-Runtime-Version", &fp.StainlessRuntimeVersion) +} + +// mergeHeader 如果请求头中存在该字段则更新目标值,否则保留原值 +func mergeHeader(headers http.Header, key string, target *string) { + if v := 
headers.Get(key); v != "" { + *target = v + } +} + // getHeaderOrDefault 获取header值,如果不存在则返回默认值 func getHeaderOrDefault(headers http.Header, key, defaultValue string) string { if v := headers.Get(key); v != "" { @@ -277,19 +317,19 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b // 获取或生成固定的伪装 session ID maskedSessionID, err := s.cache.GetMaskedSessionID(ctx, account.ID) if err != nil { - log.Printf("Warning: failed to get masked session ID for account %d: %v", account.ID, err) + logger.LegacyPrintf("service.identity", "Warning: failed to get masked session ID for account %d: %v", account.ID, err) return newBody, nil } if maskedSessionID == "" { // 首次或已过期,生成新的伪装 session ID maskedSessionID = generateRandomUUID() - log.Printf("Generated new masked session ID for account %d: %s", account.ID, maskedSessionID) + logger.LegacyPrintf("service.identity", "Generated new masked session ID for account %d: %s", account.ID, maskedSessionID) } // 刷新 TTL(每次请求都刷新,保持 15 分钟有效期) if err := s.cache.SetMaskedSessionID(ctx, account.ID, maskedSessionID); err != nil { - log.Printf("Warning: failed to set masked session ID for account %d: %v", account.ID, err) + logger.LegacyPrintf("service.identity", "Warning: failed to set masked session ID for account %d: %v", account.ID, err) } // 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容 @@ -335,7 +375,7 @@ func generateClientID() string { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { // 极罕见的情况,使用时间戳+固定值作为fallback - log.Printf("Warning: crypto/rand.Read failed: %v, using fallback", err) + logger.LegacyPrintf("service.identity", "Warning: crypto/rand.Read failed: %v, using fallback", err) // 使用SHA256(当前纳秒时间)作为fallback h := sha256.Sum256([]byte(fmt.Sprintf("%d", time.Now().UnixNano()))) return hex.EncodeToString(h[:]) @@ -370,8 +410,25 @@ func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) { return major, minor, patch, true } +// extractProduct 提取 User-Agent 中 "/" 前的产品名 +// 
例如:claude-cli/2.1.22 (external, cli) -> "claude-cli" +func extractProduct(ua string) string { + if idx := strings.Index(ua, "/"); idx > 0 { + return strings.ToLower(ua[:idx]) + } + return "" +} + // isNewerVersion 比较版本号,判断newUA是否比cachedUA更新 +// 要求产品名一致(防止浏览器 UA 如 Mozilla/5.0 误判为更新版本) func isNewerVersion(newUA, cachedUA string) bool { + // 校验产品名一致性 + newProduct := extractProduct(newUA) + cachedProduct := extractProduct(cachedUA) + if newProduct == "" || cachedProduct == "" || newProduct != cachedProduct { + return false + } + newMajor, newMinor, newPatch, newOk := parseUserAgentVersion(newUA) cachedMajor, cachedMinor, cachedPatch, cachedOk := parseUserAgentVersion(cachedUA) diff --git a/backend/internal/service/model_rate_limit.go b/backend/internal/service/model_rate_limit.go index ff4b5977..c45615cc 100644 --- a/backend/internal/service/model_rate_limit.go +++ b/backend/internal/service/model_rate_limit.go @@ -4,8 +4,6 @@ import ( "context" "strings" "time" - - "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" ) const modelRateLimitsKey = "model_rate_limits" @@ -73,7 +71,7 @@ func resolveFinalAntigravityModelKey(ctx context.Context, account *Account, requ return "" } // thinking 会影响 Antigravity 最终模型名(例如 claude-sonnet-4-5 -> claude-sonnet-4-5-thinking) - if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok { + if enabled, ok := ThinkingEnabledFromContext(ctx); ok { modelKey = applyThinkingModelSuffix(modelKey, enabled) } return modelKey diff --git a/backend/internal/service/oauth_service.go b/backend/internal/service/oauth_service.go index 15543080..0931f9ce 100644 --- a/backend/internal/service/oauth_service.go +++ b/backend/internal/service/oauth_service.go @@ -12,8 +12,9 @@ import ( // OpenAIOAuthClient interface for OpenAI OAuth operations type OpenAIOAuthClient interface { - ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) + ExchangeCode(ctx context.Context, code, codeVerifier, 
redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) + RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) } // ClaudeOAuthClient handles HTTP requests for Claude OAuth flows @@ -217,7 +218,7 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) ( // Ensure org_uuid is set (from step 1 if not from token response) if tokenInfo.OrgUUID == "" && orgUUID != "" { tokenInfo.OrgUUID = orgUUID - log.Printf("[OAuth] Set org_uuid from cookie auth: %s", orgUUID) + log.Printf("[OAuth] Set org_uuid from cookie auth") } return tokenInfo, nil @@ -251,16 +252,16 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif if tokenResp.Organization != nil && tokenResp.Organization.UUID != "" { tokenInfo.OrgUUID = tokenResp.Organization.UUID - log.Printf("[OAuth] Got org_uuid: %s", tokenInfo.OrgUUID) + log.Printf("[OAuth] Got org_uuid") } if tokenResp.Account != nil { if tokenResp.Account.UUID != "" { tokenInfo.AccountUUID = tokenResp.Account.UUID - log.Printf("[OAuth] Got account_uuid: %s", tokenInfo.AccountUUID) + log.Printf("[OAuth] Got account_uuid") } if tokenResp.Account.EmailAddress != "" { tokenInfo.EmailAddress = tokenResp.Account.EmailAddress - log.Printf("[OAuth] Got email_address: %s", tokenInfo.EmailAddress) + log.Printf("[OAuth] Got email_address") } } diff --git a/backend/internal/service/oauth_service_test.go b/backend/internal/service/oauth_service_test.go new file mode 100644 index 00000000..78f39dc5 --- /dev/null +++ b/backend/internal/service/oauth_service_test.go @@ -0,0 +1,607 @@ +//go:build unit + +package service + +import ( + "context" + "fmt" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" +) + +// --- mock: ClaudeOAuthClient --- + +type 
mockClaudeOAuthClient struct { + getOrgUUIDFunc func(ctx context.Context, sessionKey, proxyURL string) (string, error) + getAuthCodeFunc func(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) + exchangeCodeFunc func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) + refreshTokenFunc func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) +} + +func (m *mockClaudeOAuthClient) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) { + if m.getOrgUUIDFunc != nil { + return m.getOrgUUIDFunc(ctx, sessionKey, proxyURL) + } + panic("GetOrganizationUUID not implemented") +} + +func (m *mockClaudeOAuthClient) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) { + if m.getAuthCodeFunc != nil { + return m.getAuthCodeFunc(ctx, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL) + } + panic("GetAuthorizationCode not implemented") +} + +func (m *mockClaudeOAuthClient) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { + if m.exchangeCodeFunc != nil { + return m.exchangeCodeFunc(ctx, code, codeVerifier, state, proxyURL, isSetupToken) + } + panic("ExchangeCodeForToken not implemented") +} + +func (m *mockClaudeOAuthClient) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { + if m.refreshTokenFunc != nil { + return m.refreshTokenFunc(ctx, refreshToken, proxyURL) + } + panic("RefreshToken not implemented") +} + +// --- mock: ProxyRepository (最小实现,仅覆盖 OAuthService 依赖的方法) --- + +type mockProxyRepoForOAuth struct { + getByIDFunc func(ctx context.Context, id int64) (*Proxy, error) +} + +func (m *mockProxyRepoForOAuth) Create(ctx context.Context, proxy *Proxy) error { + panic("Create not 
implemented") +} +func (m *mockProxyRepoForOAuth) GetByID(ctx context.Context, id int64) (*Proxy, error) { + if m.getByIDFunc != nil { + return m.getByIDFunc(ctx, id) + } + return nil, fmt.Errorf("proxy not found") +} +func (m *mockProxyRepoForOAuth) ListByIDs(ctx context.Context, ids []int64) ([]Proxy, error) { + panic("ListByIDs not implemented") +} +func (m *mockProxyRepoForOAuth) Update(ctx context.Context, proxy *Proxy) error { + panic("Update not implemented") +} +func (m *mockProxyRepoForOAuth) Delete(ctx context.Context, id int64) error { + panic("Delete not implemented") +} +func (m *mockProxyRepoForOAuth) List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) { + panic("List not implemented") +} +func (m *mockProxyRepoForOAuth) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) { + panic("ListWithFilters not implemented") +} +func (m *mockProxyRepoForOAuth) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) { + panic("ListWithFiltersAndAccountCount not implemented") +} +func (m *mockProxyRepoForOAuth) ListActive(ctx context.Context) ([]Proxy, error) { + panic("ListActive not implemented") +} +func (m *mockProxyRepoForOAuth) ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) { + panic("ListActiveWithAccountCount not implemented") +} +func (m *mockProxyRepoForOAuth) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) { + panic("ExistsByHostPortAuth not implemented") +} +func (m *mockProxyRepoForOAuth) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) { + panic("CountAccountsByProxyID not implemented") +} +func (m *mockProxyRepoForOAuth) 
ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) { + panic("ListAccountSummariesByProxyID not implemented") +} + +// ===================== +// 测试用例 +// ===================== + +func TestNewOAuthService(t *testing.T) { + t.Parallel() + + proxyRepo := &mockProxyRepoForOAuth{} + client := &mockClaudeOAuthClient{} + svc := NewOAuthService(proxyRepo, client) + + if svc == nil { + t.Fatal("NewOAuthService 返回 nil") + } + if svc.proxyRepo != proxyRepo { + t.Fatal("proxyRepo 未正确设置") + } + if svc.oauthClient != client { + t.Fatal("oauthClient 未正确设置") + } + if svc.sessionStore == nil { + t.Fatal("sessionStore 应被自动初始化") + } + + // 清理 + svc.Stop() +} + +func TestOAuthService_GenerateAuthURL(t *testing.T) { + t.Parallel() + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{}) + defer svc.Stop() + + result, err := svc.GenerateAuthURL(context.Background(), nil) + if err != nil { + t.Fatalf("GenerateAuthURL 返回错误: %v", err) + } + if result == nil { + t.Fatal("GenerateAuthURL 返回 nil") + } + if result.AuthURL == "" { + t.Fatal("AuthURL 为空") + } + if result.SessionID == "" { + t.Fatal("SessionID 为空") + } + + // 验证 session 已存储 + session, ok := svc.sessionStore.Get(result.SessionID) + if !ok { + t.Fatal("session 未在 sessionStore 中找到") + } + if session.Scope != oauth.ScopeOAuth { + t.Fatalf("scope 不匹配: got=%q want=%q", session.Scope, oauth.ScopeOAuth) + } +} + +func TestOAuthService_GenerateAuthURL_WithProxy(t *testing.T) { + t.Parallel() + + proxyRepo := &mockProxyRepoForOAuth{ + getByIDFunc: func(ctx context.Context, id int64) (*Proxy, error) { + return &Proxy{ + ID: 1, + Protocol: "http", + Host: "proxy.example.com", + Port: 8080, + }, nil + }, + } + svc := NewOAuthService(proxyRepo, &mockClaudeOAuthClient{}) + defer svc.Stop() + + proxyID := int64(1) + result, err := svc.GenerateAuthURL(context.Background(), &proxyID) + if err != nil { + t.Fatalf("GenerateAuthURL 返回错误: %v", err) + } + + session, ok := 
svc.sessionStore.Get(result.SessionID) + if !ok { + t.Fatal("session 未在 sessionStore 中找到") + } + if session.ProxyURL != "http://proxy.example.com:8080" { + t.Fatalf("ProxyURL 不匹配: got=%q", session.ProxyURL) + } +} + +func TestOAuthService_GenerateSetupTokenURL(t *testing.T) { + t.Parallel() + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{}) + defer svc.Stop() + + result, err := svc.GenerateSetupTokenURL(context.Background(), nil) + if err != nil { + t.Fatalf("GenerateSetupTokenURL 返回错误: %v", err) + } + if result == nil { + t.Fatal("GenerateSetupTokenURL 返回 nil") + } + + // 验证 scope 是 inference + session, ok := svc.sessionStore.Get(result.SessionID) + if !ok { + t.Fatal("session 未在 sessionStore 中找到") + } + if session.Scope != oauth.ScopeInference { + t.Fatalf("scope 不匹配: got=%q want=%q", session.Scope, oauth.ScopeInference) + } +} + +func TestOAuthService_ExchangeCode_SessionNotFound(t *testing.T) { + t.Parallel() + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{}) + defer svc.Stop() + + _, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{ + SessionID: "nonexistent-session", + Code: "test-code", + }) + if err == nil { + t.Fatal("ExchangeCode 应返回错误(session 不存在)") + } + if err.Error() != "session not found or expired" { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestOAuthService_ExchangeCode_Success(t *testing.T) { + t.Parallel() + + exchangeCalled := false + client := &mockClaudeOAuthClient{ + exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { + exchangeCalled = true + if code != "auth-code-123" { + t.Errorf("code 不匹配: got=%q", code) + } + if isSetupToken { + t.Error("isSetupToken 应为 false(ScopeOAuth)") + } + return &oauth.TokenResponse{ + AccessToken: "access-token-abc", + TokenType: "Bearer", + ExpiresIn: 3600, + RefreshToken: "refresh-token-xyz", + Scope: oauth.ScopeOAuth, + Organization: 
&oauth.OrgInfo{UUID: "org-uuid-111"}, + Account: &oauth.AccountInfo{UUID: "acc-uuid-222", EmailAddress: "test@example.com"}, + }, nil + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + // 先生成 URL 以创建 session + result, err := svc.GenerateAuthURL(context.Background(), nil) + if err != nil { + t.Fatalf("GenerateAuthURL 返回错误: %v", err) + } + + // 交换 code + tokenInfo, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{ + SessionID: result.SessionID, + Code: "auth-code-123", + }) + if err != nil { + t.Fatalf("ExchangeCode 返回错误: %v", err) + } + + if !exchangeCalled { + t.Fatal("ExchangeCodeForToken 未被调用") + } + if tokenInfo.AccessToken != "access-token-abc" { + t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken) + } + if tokenInfo.TokenType != "Bearer" { + t.Fatalf("TokenType 不匹配: got=%q", tokenInfo.TokenType) + } + if tokenInfo.RefreshToken != "refresh-token-xyz" { + t.Fatalf("RefreshToken 不匹配: got=%q", tokenInfo.RefreshToken) + } + if tokenInfo.OrgUUID != "org-uuid-111" { + t.Fatalf("OrgUUID 不匹配: got=%q", tokenInfo.OrgUUID) + } + if tokenInfo.AccountUUID != "acc-uuid-222" { + t.Fatalf("AccountUUID 不匹配: got=%q", tokenInfo.AccountUUID) + } + if tokenInfo.EmailAddress != "test@example.com" { + t.Fatalf("EmailAddress 不匹配: got=%q", tokenInfo.EmailAddress) + } + if tokenInfo.ExpiresIn != 3600 { + t.Fatalf("ExpiresIn 不匹配: got=%d", tokenInfo.ExpiresIn) + } + if tokenInfo.ExpiresAt == 0 { + t.Fatal("ExpiresAt 不应为 0") + } + + // 验证 session 已被删除 + _, ok := svc.sessionStore.Get(result.SessionID) + if ok { + t.Fatal("session 应在交换成功后被删除") + } +} + +func TestOAuthService_ExchangeCode_SetupToken(t *testing.T) { + t.Parallel() + + client := &mockClaudeOAuthClient{ + exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { + if !isSetupToken { + t.Error("isSetupToken 应为 true(ScopeInference)") + } + return &oauth.TokenResponse{ + AccessToken: 
"setup-token", + TokenType: "Bearer", + ExpiresIn: 3600, + Scope: oauth.ScopeInference, + }, nil + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + // 使用 SetupToken URL(inference scope) + result, err := svc.GenerateSetupTokenURL(context.Background(), nil) + if err != nil { + t.Fatalf("GenerateSetupTokenURL 返回错误: %v", err) + } + + tokenInfo, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{ + SessionID: result.SessionID, + Code: "setup-code", + }) + if err != nil { + t.Fatalf("ExchangeCode 返回错误: %v", err) + } + if tokenInfo.AccessToken != "setup-token" { + t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken) + } +} + +func TestOAuthService_ExchangeCode_ClientError(t *testing.T) { + t.Parallel() + + client := &mockClaudeOAuthClient{ + exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { + return nil, fmt.Errorf("upstream error: invalid code") + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + result, _ := svc.GenerateAuthURL(context.Background(), nil) + _, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{ + SessionID: result.SessionID, + Code: "bad-code", + }) + if err == nil { + t.Fatal("ExchangeCode 应返回错误") + } + if err.Error() != "upstream error: invalid code" { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestOAuthService_RefreshToken(t *testing.T) { + t.Parallel() + + client := &mockClaudeOAuthClient{ + refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { + if refreshToken != "my-refresh-token" { + t.Errorf("refreshToken 不匹配: got=%q", refreshToken) + } + if proxyURL != "" { + t.Errorf("proxyURL 应为空: got=%q", proxyURL) + } + return &oauth.TokenResponse{ + AccessToken: "new-access-token", + TokenType: "Bearer", + ExpiresIn: 7200, + RefreshToken: "new-refresh-token", + Scope: oauth.ScopeOAuth, + 
}, nil + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + tokenInfo, err := svc.RefreshToken(context.Background(), "my-refresh-token", "") + if err != nil { + t.Fatalf("RefreshToken 返回错误: %v", err) + } + if tokenInfo.AccessToken != "new-access-token" { + t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken) + } + if tokenInfo.RefreshToken != "new-refresh-token" { + t.Fatalf("RefreshToken 不匹配: got=%q", tokenInfo.RefreshToken) + } + if tokenInfo.ExpiresIn != 7200 { + t.Fatalf("ExpiresIn 不匹配: got=%d", tokenInfo.ExpiresIn) + } + if tokenInfo.ExpiresAt == 0 { + t.Fatal("ExpiresAt 不应为 0") + } +} + +func TestOAuthService_RefreshToken_Error(t *testing.T) { + t.Parallel() + + client := &mockClaudeOAuthClient{ + refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { + return nil, fmt.Errorf("invalid_grant: token expired") + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + _, err := svc.RefreshToken(context.Background(), "expired-token", "") + if err == nil { + t.Fatal("RefreshToken 应返回错误") + } +} + +func TestOAuthService_RefreshAccountToken_NoRefreshToken(t *testing.T) { + t.Parallel() + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{}) + defer svc.Stop() + + // 无 refresh_token 的账号 + account := &Account{ + ID: 1, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "some-token", + }, + } + _, err := svc.RefreshAccountToken(context.Background(), account) + if err == nil { + t.Fatal("RefreshAccountToken 应返回错误(无 refresh_token)") + } + if err.Error() != "no refresh token available" { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestOAuthService_RefreshAccountToken_EmptyRefreshToken(t *testing.T) { + t.Parallel() + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{}) + defer svc.Stop() + + account := &Account{ + ID: 2, + 
Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "some-token", + "refresh_token": "", + }, + } + _, err := svc.RefreshAccountToken(context.Background(), account) + if err == nil { + t.Fatal("RefreshAccountToken 应返回错误(refresh_token 为空)") + } +} + +func TestOAuthService_RefreshAccountToken_Success(t *testing.T) { + t.Parallel() + + client := &mockClaudeOAuthClient{ + refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { + if refreshToken != "account-refresh-token" { + t.Errorf("refreshToken 不匹配: got=%q", refreshToken) + } + return &oauth.TokenResponse{ + AccessToken: "refreshed-access", + TokenType: "Bearer", + ExpiresIn: 3600, + RefreshToken: "new-refresh", + }, nil + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + account := &Account{ + ID: 3, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "old-access", + "refresh_token": "account-refresh-token", + }, + } + + tokenInfo, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + if tokenInfo.AccessToken != "refreshed-access" { + t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken) + } +} + +func TestOAuthService_RefreshAccountToken_WithProxy(t *testing.T) { + t.Parallel() + + proxyRepo := &mockProxyRepoForOAuth{ + getByIDFunc: func(ctx context.Context, id int64) (*Proxy, error) { + return &Proxy{ + Protocol: "socks5", + Host: "socks.example.com", + Port: 1080, + Username: "user", + Password: "pass", + }, nil + }, + } + + client := &mockClaudeOAuthClient{ + refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { + if proxyURL != "socks5://user:pass@socks.example.com:1080" { + t.Errorf("proxyURL 不匹配: got=%q", proxyURL) + } + return &oauth.TokenResponse{ + AccessToken: "refreshed", + 
ExpiresIn: 3600, + }, nil + }, + } + + svc := NewOAuthService(proxyRepo, client) + defer svc.Stop() + + proxyID := int64(10) + account := &Account{ + ID: 4, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + ProxyID: &proxyID, + Credentials: map[string]any{ + "refresh_token": "rt-with-proxy", + }, + } + + _, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } +} + +func TestOAuthService_ExchangeCode_NilOrg(t *testing.T) { + t.Parallel() + + client := &mockClaudeOAuthClient{ + exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { + return &oauth.TokenResponse{ + AccessToken: "token-no-org", + TokenType: "Bearer", + ExpiresIn: 3600, + Organization: nil, + Account: nil, + }, nil + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + result, _ := svc.GenerateAuthURL(context.Background(), nil) + tokenInfo, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{ + SessionID: result.SessionID, + Code: "code", + }) + if err != nil { + t.Fatalf("ExchangeCode 返回错误: %v", err) + } + if tokenInfo.OrgUUID != "" { + t.Fatalf("OrgUUID 应为空: got=%q", tokenInfo.OrgUUID) + } + if tokenInfo.AccountUUID != "" { + t.Fatalf("AccountUUID 应为空: got=%q", tokenInfo.AccountUUID) + } +} + +func TestOAuthService_Stop_NoPanic(t *testing.T) { + t.Parallel() + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{}) + + // 调用 Stop 不应 panic + svc.Stop() + + // 多次调用也不应 panic + svc.Stop() +} diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go new file mode 100644 index 00000000..99013ce5 --- /dev/null +++ b/backend/internal/service/openai_account_scheduler.go @@ -0,0 +1,909 @@ +package service + +import ( + "container/heap" + "context" + "errors" + "hash/fnv" + "math" + "sort" + "strconv" + 
"strings" + "sync" + "sync/atomic" + "time" +) + +const ( + openAIAccountScheduleLayerPreviousResponse = "previous_response_id" + openAIAccountScheduleLayerSessionSticky = "session_hash" + openAIAccountScheduleLayerLoadBalance = "load_balance" +) + +type OpenAIAccountScheduleRequest struct { + GroupID *int64 + SessionHash string + StickyAccountID int64 + PreviousResponseID string + RequestedModel string + RequiredTransport OpenAIUpstreamTransport + ExcludedIDs map[int64]struct{} +} + +type OpenAIAccountScheduleDecision struct { + Layer string + StickyPreviousHit bool + StickySessionHit bool + CandidateCount int + TopK int + LatencyMs int64 + LoadSkew float64 + SelectedAccountID int64 + SelectedAccountType string +} + +type OpenAIAccountSchedulerMetricsSnapshot struct { + SelectTotal int64 + StickyPreviousHitTotal int64 + StickySessionHitTotal int64 + LoadBalanceSelectTotal int64 + AccountSwitchTotal int64 + SchedulerLatencyMsTotal int64 + SchedulerLatencyMsAvg float64 + StickyHitRatio float64 + AccountSwitchRate float64 + LoadSkewAvg float64 + RuntimeStatsAccountCount int +} + +type OpenAIAccountScheduler interface { + Select(ctx context.Context, req OpenAIAccountScheduleRequest) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) + ReportResult(accountID int64, success bool, firstTokenMs *int) + ReportSwitch() + SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot +} + +type openAIAccountSchedulerMetrics struct { + selectTotal atomic.Int64 + stickyPreviousHitTotal atomic.Int64 + stickySessionHitTotal atomic.Int64 + loadBalanceSelectTotal atomic.Int64 + accountSwitchTotal atomic.Int64 + latencyMsTotal atomic.Int64 + loadSkewMilliTotal atomic.Int64 +} + +func (m *openAIAccountSchedulerMetrics) recordSelect(decision OpenAIAccountScheduleDecision) { + if m == nil { + return + } + m.selectTotal.Add(1) + m.latencyMsTotal.Add(decision.LatencyMs) + m.loadSkewMilliTotal.Add(int64(math.Round(decision.LoadSkew * 1000))) + if decision.StickyPreviousHit { + 
m.stickyPreviousHitTotal.Add(1) + } + if decision.StickySessionHit { + m.stickySessionHitTotal.Add(1) + } + if decision.Layer == openAIAccountScheduleLayerLoadBalance { + m.loadBalanceSelectTotal.Add(1) + } +} + +func (m *openAIAccountSchedulerMetrics) recordSwitch() { + if m == nil { + return + } + m.accountSwitchTotal.Add(1) +} + +type openAIAccountRuntimeStats struct { + accounts sync.Map + accountCount atomic.Int64 +} + +type openAIAccountRuntimeStat struct { + errorRateEWMABits atomic.Uint64 + ttftEWMABits atomic.Uint64 +} + +func newOpenAIAccountRuntimeStats() *openAIAccountRuntimeStats { + return &openAIAccountRuntimeStats{} +} + +func (s *openAIAccountRuntimeStats) loadOrCreate(accountID int64) *openAIAccountRuntimeStat { + if value, ok := s.accounts.Load(accountID); ok { + stat, _ := value.(*openAIAccountRuntimeStat) + if stat != nil { + return stat + } + } + + stat := &openAIAccountRuntimeStat{} + stat.ttftEWMABits.Store(math.Float64bits(math.NaN())) + actual, loaded := s.accounts.LoadOrStore(accountID, stat) + if !loaded { + s.accountCount.Add(1) + return stat + } + existing, _ := actual.(*openAIAccountRuntimeStat) + if existing != nil { + return existing + } + return stat +} + +func updateEWMAAtomic(target *atomic.Uint64, sample float64, alpha float64) { + for { + oldBits := target.Load() + oldValue := math.Float64frombits(oldBits) + newValue := alpha*sample + (1-alpha)*oldValue + if target.CompareAndSwap(oldBits, math.Float64bits(newValue)) { + return + } + } +} + +func (s *openAIAccountRuntimeStats) report(accountID int64, success bool, firstTokenMs *int) { + if s == nil || accountID <= 0 { + return + } + const alpha = 0.2 + stat := s.loadOrCreate(accountID) + + errorSample := 1.0 + if success { + errorSample = 0.0 + } + updateEWMAAtomic(&stat.errorRateEWMABits, errorSample, alpha) + + if firstTokenMs != nil && *firstTokenMs > 0 { + ttft := float64(*firstTokenMs) + ttftBits := math.Float64bits(ttft) + for { + oldBits := stat.ttftEWMABits.Load() + 
oldValue := math.Float64frombits(oldBits) + if math.IsNaN(oldValue) { + if stat.ttftEWMABits.CompareAndSwap(oldBits, ttftBits) { + break + } + continue + } + newValue := alpha*ttft + (1-alpha)*oldValue + if stat.ttftEWMABits.CompareAndSwap(oldBits, math.Float64bits(newValue)) { + break + } + } + } +} + +func (s *openAIAccountRuntimeStats) snapshot(accountID int64) (errorRate float64, ttft float64, hasTTFT bool) { + if s == nil || accountID <= 0 { + return 0, 0, false + } + value, ok := s.accounts.Load(accountID) + if !ok { + return 0, 0, false + } + stat, _ := value.(*openAIAccountRuntimeStat) + if stat == nil { + return 0, 0, false + } + errorRate = clamp01(math.Float64frombits(stat.errorRateEWMABits.Load())) + ttftValue := math.Float64frombits(stat.ttftEWMABits.Load()) + if math.IsNaN(ttftValue) { + return errorRate, 0, false + } + return errorRate, ttftValue, true +} + +func (s *openAIAccountRuntimeStats) size() int { + if s == nil { + return 0 + } + return int(s.accountCount.Load()) +} + +type defaultOpenAIAccountScheduler struct { + service *OpenAIGatewayService + metrics openAIAccountSchedulerMetrics + stats *openAIAccountRuntimeStats +} + +func newDefaultOpenAIAccountScheduler(service *OpenAIGatewayService, stats *openAIAccountRuntimeStats) OpenAIAccountScheduler { + if stats == nil { + stats = newOpenAIAccountRuntimeStats() + } + return &defaultOpenAIAccountScheduler{ + service: service, + stats: stats, + } +} + +func (s *defaultOpenAIAccountScheduler) Select( + ctx context.Context, + req OpenAIAccountScheduleRequest, +) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { + decision := OpenAIAccountScheduleDecision{} + start := time.Now() + defer func() { + decision.LatencyMs = time.Since(start).Milliseconds() + s.metrics.recordSelect(decision) + }() + + previousResponseID := strings.TrimSpace(req.PreviousResponseID) + if previousResponseID != "" { + selection, err := s.service.SelectAccountByPreviousResponseID( + ctx, + req.GroupID, + 
previousResponseID, + req.RequestedModel, + req.ExcludedIDs, + ) + if err != nil { + return nil, decision, err + } + if selection != nil && selection.Account != nil { + if !s.isAccountTransportCompatible(selection.Account, req.RequiredTransport) { + selection = nil + } + } + if selection != nil && selection.Account != nil { + decision.Layer = openAIAccountScheduleLayerPreviousResponse + decision.StickyPreviousHit = true + decision.SelectedAccountID = selection.Account.ID + decision.SelectedAccountType = selection.Account.Type + if req.SessionHash != "" { + _ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, selection.Account.ID) + } + return selection, decision, nil + } + } + + selection, err := s.selectBySessionHash(ctx, req) + if err != nil { + return nil, decision, err + } + if selection != nil && selection.Account != nil { + decision.Layer = openAIAccountScheduleLayerSessionSticky + decision.StickySessionHit = true + decision.SelectedAccountID = selection.Account.ID + decision.SelectedAccountType = selection.Account.Type + return selection, decision, nil + } + + selection, candidateCount, topK, loadSkew, err := s.selectByLoadBalance(ctx, req) + decision.Layer = openAIAccountScheduleLayerLoadBalance + decision.CandidateCount = candidateCount + decision.TopK = topK + decision.LoadSkew = loadSkew + if err != nil { + return nil, decision, err + } + if selection != nil && selection.Account != nil { + decision.SelectedAccountID = selection.Account.ID + decision.SelectedAccountType = selection.Account.Type + } + return selection, decision, nil +} + +func (s *defaultOpenAIAccountScheduler) selectBySessionHash( + ctx context.Context, + req OpenAIAccountScheduleRequest, +) (*AccountSelectionResult, error) { + sessionHash := strings.TrimSpace(req.SessionHash) + if sessionHash == "" || s == nil || s.service == nil || s.service.cache == nil { + return nil, nil + } + + accountID := req.StickyAccountID + if accountID <= 0 { + var err error + accountID, err = 
s.service.getStickySessionAccountID(ctx, req.GroupID, sessionHash) + if err != nil || accountID <= 0 { + return nil, nil + } + } + if accountID <= 0 { + return nil, nil + } + if req.ExcludedIDs != nil { + if _, excluded := req.ExcludedIDs[accountID]; excluded { + return nil, nil + } + } + + account, err := s.service.getSchedulableAccount(ctx, accountID) + if err != nil || account == nil { + _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) + return nil, nil + } + if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() { + _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) + return nil, nil + } + if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) { + return nil, nil + } + if !s.isAccountTransportCompatible(account, req.RequiredTransport) { + _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) + return nil, nil + } + + result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) + if acquireErr == nil && result.Acquired { + _ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL()) + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + + cfg := s.service.schedulingConfig() + if s.service.concurrencyService != nil { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + return nil, nil +} + +type openAIAccountCandidateScore struct { + account *Account + loadInfo *AccountLoadInfo + score float64 + errorRate float64 + ttft float64 + hasTTFT bool +} + +type openAIAccountCandidateHeap []openAIAccountCandidateScore + +func (h openAIAccountCandidateHeap) Len() int { + return len(h) +} + +func (h openAIAccountCandidateHeap) 
Less(i, j int) bool { + // 最小堆根节点保存“最差”候选,便于 O(log k) 维护 topK。 + return isOpenAIAccountCandidateBetter(h[j], h[i]) +} + +func (h openAIAccountCandidateHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +func (h *openAIAccountCandidateHeap) Push(x any) { + candidate, ok := x.(openAIAccountCandidateScore) + if !ok { + panic("openAIAccountCandidateHeap: invalid element type") + } + *h = append(*h, candidate) +} + +func (h *openAIAccountCandidateHeap) Pop() any { + old := *h + n := len(old) + last := old[n-1] + *h = old[:n-1] + return last +} + +func isOpenAIAccountCandidateBetter(left openAIAccountCandidateScore, right openAIAccountCandidateScore) bool { + if left.score != right.score { + return left.score > right.score + } + if left.account.Priority != right.account.Priority { + return left.account.Priority < right.account.Priority + } + if left.loadInfo.LoadRate != right.loadInfo.LoadRate { + return left.loadInfo.LoadRate < right.loadInfo.LoadRate + } + if left.loadInfo.WaitingCount != right.loadInfo.WaitingCount { + return left.loadInfo.WaitingCount < right.loadInfo.WaitingCount + } + return left.account.ID < right.account.ID +} + +func selectTopKOpenAICandidates(candidates []openAIAccountCandidateScore, topK int) []openAIAccountCandidateScore { + if len(candidates) == 0 { + return nil + } + if topK <= 0 { + topK = 1 + } + if topK >= len(candidates) { + ranked := append([]openAIAccountCandidateScore(nil), candidates...) 
+ sort.Slice(ranked, func(i, j int) bool { + return isOpenAIAccountCandidateBetter(ranked[i], ranked[j]) + }) + return ranked + } + + best := make(openAIAccountCandidateHeap, 0, topK) + for _, candidate := range candidates { + if len(best) < topK { + heap.Push(&best, candidate) + continue + } + if isOpenAIAccountCandidateBetter(candidate, best[0]) { + best[0] = candidate + heap.Fix(&best, 0) + } + } + + ranked := make([]openAIAccountCandidateScore, len(best)) + copy(ranked, best) + sort.Slice(ranked, func(i, j int) bool { + return isOpenAIAccountCandidateBetter(ranked[i], ranked[j]) + }) + return ranked +} + +type openAISelectionRNG struct { + state uint64 +} + +func newOpenAISelectionRNG(seed uint64) openAISelectionRNG { + if seed == 0 { + seed = 0x9e3779b97f4a7c15 + } + return openAISelectionRNG{state: seed} +} + +func (r *openAISelectionRNG) nextUint64() uint64 { + // xorshift64* + x := r.state + x ^= x >> 12 + x ^= x << 25 + x ^= x >> 27 + r.state = x + return x * 2685821657736338717 +} + +func (r *openAISelectionRNG) nextFloat64() float64 { + // [0,1) + return float64(r.nextUint64()>>11) / (1 << 53) +} + +func deriveOpenAISelectionSeed(req OpenAIAccountScheduleRequest) uint64 { + hasher := fnv.New64a() + writeValue := func(value string) { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return + } + _, _ = hasher.Write([]byte(trimmed)) + _, _ = hasher.Write([]byte{0}) + } + + writeValue(req.SessionHash) + writeValue(req.PreviousResponseID) + writeValue(req.RequestedModel) + if req.GroupID != nil { + _, _ = hasher.Write([]byte(strconv.FormatInt(*req.GroupID, 10))) + } + + seed := hasher.Sum64() + // 对“无会话锚点”的纯负载均衡请求引入时间熵,避免固定命中同一账号。 + if strings.TrimSpace(req.SessionHash) == "" && strings.TrimSpace(req.PreviousResponseID) == "" { + seed ^= uint64(time.Now().UnixNano()) + } + if seed == 0 { + seed = uint64(time.Now().UnixNano()) ^ 0x9e3779b97f4a7c15 + } + return seed +} + +func buildOpenAIWeightedSelectionOrder( + candidates 
[]openAIAccountCandidateScore, + req OpenAIAccountScheduleRequest, +) []openAIAccountCandidateScore { + if len(candidates) <= 1 { + return append([]openAIAccountCandidateScore(nil), candidates...) + } + + pool := append([]openAIAccountCandidateScore(nil), candidates...) + weights := make([]float64, len(pool)) + minScore := pool[0].score + for i := 1; i < len(pool); i++ { + if pool[i].score < minScore { + minScore = pool[i].score + } + } + for i := range pool { + // 将 top-K 分值平移到正区间,避免“单一最高分账号”长期垄断。 + weight := (pool[i].score - minScore) + 1.0 + if math.IsNaN(weight) || math.IsInf(weight, 0) || weight <= 0 { + weight = 1.0 + } + weights[i] = weight + } + + order := make([]openAIAccountCandidateScore, 0, len(pool)) + rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req)) + for len(pool) > 0 { + total := 0.0 + for _, w := range weights { + total += w + } + + selectedIdx := 0 + if total > 0 { + r := rng.nextFloat64() * total + acc := 0.0 + for i, w := range weights { + acc += w + if r <= acc { + selectedIdx = i + break + } + } + } else { + selectedIdx = int(rng.nextUint64() % uint64(len(pool))) + } + + order = append(order, pool[selectedIdx]) + pool = append(pool[:selectedIdx], pool[selectedIdx+1:]...) + weights = append(weights[:selectedIdx], weights[selectedIdx+1:]...) 
+ } + return order +} + +func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( + ctx context.Context, + req OpenAIAccountScheduleRequest, +) (*AccountSelectionResult, int, int, float64, error) { + accounts, err := s.service.listSchedulableAccounts(ctx, req.GroupID) + if err != nil { + return nil, 0, 0, 0, err + } + if len(accounts) == 0 { + return nil, 0, 0, 0, errors.New("no available OpenAI accounts") + } + + filtered := make([]*Account, 0, len(accounts)) + loadReq := make([]AccountWithConcurrency, 0, len(accounts)) + for i := range accounts { + account := &accounts[i] + if req.ExcludedIDs != nil { + if _, excluded := req.ExcludedIDs[account.ID]; excluded { + continue + } + } + if !account.IsSchedulable() || !account.IsOpenAI() { + continue + } + if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) { + continue + } + if !s.isAccountTransportCompatible(account, req.RequiredTransport) { + continue + } + filtered = append(filtered, account) + loadReq = append(loadReq, AccountWithConcurrency{ + ID: account.ID, + MaxConcurrency: account.Concurrency, + }) + } + if len(filtered) == 0 { + return nil, 0, 0, 0, errors.New("no available OpenAI accounts") + } + + loadMap := map[int64]*AccountLoadInfo{} + if s.service.concurrencyService != nil { + if batchLoad, loadErr := s.service.concurrencyService.GetAccountsLoadBatch(ctx, loadReq); loadErr == nil { + loadMap = batchLoad + } + } + + minPriority, maxPriority := filtered[0].Priority, filtered[0].Priority + maxWaiting := 1 + loadRateSum := 0.0 + loadRateSumSquares := 0.0 + minTTFT, maxTTFT := 0.0, 0.0 + hasTTFTSample := false + candidates := make([]openAIAccountCandidateScore, 0, len(filtered)) + for _, account := range filtered { + loadInfo := loadMap[account.ID] + if loadInfo == nil { + loadInfo = &AccountLoadInfo{AccountID: account.ID} + } + if account.Priority < minPriority { + minPriority = account.Priority + } + if account.Priority > maxPriority { + maxPriority = account.Priority + } + if 
loadInfo.WaitingCount > maxWaiting { + maxWaiting = loadInfo.WaitingCount + } + errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID) + if hasTTFT && ttft > 0 { + if !hasTTFTSample { + minTTFT, maxTTFT = ttft, ttft + hasTTFTSample = true + } else { + if ttft < minTTFT { + minTTFT = ttft + } + if ttft > maxTTFT { + maxTTFT = ttft + } + } + } + loadRate := float64(loadInfo.LoadRate) + loadRateSum += loadRate + loadRateSumSquares += loadRate * loadRate + candidates = append(candidates, openAIAccountCandidateScore{ + account: account, + loadInfo: loadInfo, + errorRate: errorRate, + ttft: ttft, + hasTTFT: hasTTFT, + }) + } + loadSkew := calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates)) + + weights := s.service.openAIWSSchedulerWeights() + for i := range candidates { + item := &candidates[i] + priorityFactor := 1.0 + if maxPriority > minPriority { + priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority) + } + loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0) + queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting)) + errorFactor := 1 - clamp01(item.errorRate) + ttftFactor := 0.5 + if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT { + ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT)) + } + + item.score = weights.Priority*priorityFactor + + weights.Load*loadFactor + + weights.Queue*queueFactor + + weights.ErrorRate*errorFactor + + weights.TTFT*ttftFactor + } + + topK := s.service.openAIWSLBTopK() + if topK > len(candidates) { + topK = len(candidates) + } + if topK <= 0 { + topK = 1 + } + rankedCandidates := selectTopKOpenAICandidates(candidates, topK) + selectionOrder := buildOpenAIWeightedSelectionOrder(rankedCandidates, req) + + for i := 0; i < len(selectionOrder); i++ { + candidate := selectionOrder[i] + result, acquireErr := s.service.tryAcquireAccountSlot(ctx, candidate.account.ID, candidate.account.Concurrency) + if acquireErr != nil { + 
return nil, len(candidates), topK, loadSkew, acquireErr + } + if result != nil && result.Acquired { + if req.SessionHash != "" { + _ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, candidate.account.ID) + } + return &AccountSelectionResult{ + Account: candidate.account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, len(candidates), topK, loadSkew, nil + } + } + + cfg := s.service.schedulingConfig() + candidate := selectionOrder[0] + return &AccountSelectionResult{ + Account: candidate.account, + WaitPlan: &AccountWaitPlan{ + AccountID: candidate.account.ID, + MaxConcurrency: candidate.account.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, len(candidates), topK, loadSkew, nil +} + +func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool { + // HTTP 入站可回退到 HTTP 线路,不需要在账号选择阶段做传输协议强过滤。 + if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE { + return true + } + if s == nil || s.service == nil || account == nil { + return false + } + return s.service.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport +} + +func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) { + if s == nil || s.stats == nil { + return + } + s.stats.report(accountID, success, firstTokenMs) +} + +func (s *defaultOpenAIAccountScheduler) ReportSwitch() { + if s == nil { + return + } + s.metrics.recordSwitch() +} + +func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot { + if s == nil { + return OpenAIAccountSchedulerMetricsSnapshot{} + } + + selectTotal := s.metrics.selectTotal.Load() + prevHit := s.metrics.stickyPreviousHitTotal.Load() + sessionHit := s.metrics.stickySessionHitTotal.Load() + switchTotal := s.metrics.accountSwitchTotal.Load() + latencyTotal := 
s.metrics.latencyMsTotal.Load() + loadSkewTotal := s.metrics.loadSkewMilliTotal.Load() + + snapshot := OpenAIAccountSchedulerMetricsSnapshot{ + SelectTotal: selectTotal, + StickyPreviousHitTotal: prevHit, + StickySessionHitTotal: sessionHit, + LoadBalanceSelectTotal: s.metrics.loadBalanceSelectTotal.Load(), + AccountSwitchTotal: switchTotal, + SchedulerLatencyMsTotal: latencyTotal, + RuntimeStatsAccountCount: s.stats.size(), + } + if selectTotal > 0 { + snapshot.SchedulerLatencyMsAvg = float64(latencyTotal) / float64(selectTotal) + snapshot.StickyHitRatio = float64(prevHit+sessionHit) / float64(selectTotal) + snapshot.AccountSwitchRate = float64(switchTotal) / float64(selectTotal) + snapshot.LoadSkewAvg = float64(loadSkewTotal) / 1000 / float64(selectTotal) + } + return snapshot +} + +func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler { + if s == nil { + return nil + } + s.openaiSchedulerOnce.Do(func() { + if s.openaiAccountStats == nil { + s.openaiAccountStats = newOpenAIAccountRuntimeStats() + } + if s.openaiScheduler == nil { + s.openaiScheduler = newDefaultOpenAIAccountScheduler(s, s.openaiAccountStats) + } + }) + return s.openaiScheduler +} + +func (s *OpenAIGatewayService) SelectAccountWithScheduler( + ctx context.Context, + groupID *int64, + previousResponseID string, + sessionHash string, + requestedModel string, + excludedIDs map[int64]struct{}, + requiredTransport OpenAIUpstreamTransport, +) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { + decision := OpenAIAccountScheduleDecision{} + scheduler := s.getOpenAIAccountScheduler() + if scheduler == nil { + selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs) + decision.Layer = openAIAccountScheduleLayerLoadBalance + return selection, decision, err + } + + var stickyAccountID int64 + if sessionHash != "" && s.cache != nil { + if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == 
nil && accountID > 0 { + stickyAccountID = accountID + } + } + + return scheduler.Select(ctx, OpenAIAccountScheduleRequest{ + GroupID: groupID, + SessionHash: sessionHash, + StickyAccountID: stickyAccountID, + PreviousResponseID: previousResponseID, + RequestedModel: requestedModel, + RequiredTransport: requiredTransport, + ExcludedIDs: excludedIDs, + }) +} + +func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) { + scheduler := s.getOpenAIAccountScheduler() + if scheduler == nil { + return + } + scheduler.ReportResult(accountID, success, firstTokenMs) +} + +func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() { + scheduler := s.getOpenAIAccountScheduler() + if scheduler == nil { + return + } + scheduler.ReportSwitch() +} + +func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot { + scheduler := s.getOpenAIAccountScheduler() + if scheduler == nil { + return OpenAIAccountSchedulerMetricsSnapshot{} + } + return scheduler.SnapshotMetrics() +} + +func (s *OpenAIGatewayService) openAIWSSessionStickyTTL() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) * time.Second + } + return openaiStickySessionTTL +} + +func (s *OpenAIGatewayService) openAIWSLBTopK() int { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.LBTopK > 0 { + return s.cfg.Gateway.OpenAIWS.LBTopK + } + return 7 +} + +func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedulerScoreWeightsView { + if s != nil && s.cfg != nil { + return GatewayOpenAIWSSchedulerScoreWeightsView{ + Priority: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority, + Load: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load, + Queue: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue, + ErrorRate: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate, + TTFT: 
s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT, + } + } + return GatewayOpenAIWSSchedulerScoreWeightsView{ + Priority: 1.0, + Load: 1.0, + Queue: 0.7, + ErrorRate: 0.8, + TTFT: 0.5, + } +} + +type GatewayOpenAIWSSchedulerScoreWeightsView struct { + Priority float64 + Load float64 + Queue float64 + ErrorRate float64 + TTFT float64 +} + +func clamp01(value float64) float64 { + switch { + case value < 0: + return 0 + case value > 1: + return 1 + default: + return value + } +} + +func calcLoadSkewByMoments(sum float64, sumSquares float64, count int) float64 { + if count <= 1 { + return 0 + } + mean := sum / float64(count) + variance := sumSquares/float64(count) - mean*mean + if variance < 0 { + variance = 0 + } + return math.Sqrt(variance) +} diff --git a/backend/internal/service/openai_account_scheduler_benchmark_test.go b/backend/internal/service/openai_account_scheduler_benchmark_test.go new file mode 100644 index 00000000..897be5b0 --- /dev/null +++ b/backend/internal/service/openai_account_scheduler_benchmark_test.go @@ -0,0 +1,83 @@ +package service + +import ( + "sort" + "testing" +) + +func buildOpenAISchedulerBenchmarkCandidates(size int) []openAIAccountCandidateScore { + if size <= 0 { + return nil + } + candidates := make([]openAIAccountCandidateScore, 0, size) + for i := 0; i < size; i++ { + accountID := int64(10_000 + i) + candidates = append(candidates, openAIAccountCandidateScore{ + account: &Account{ + ID: accountID, + Priority: i % 7, + }, + loadInfo: &AccountLoadInfo{ + AccountID: accountID, + LoadRate: (i * 17) % 100, + WaitingCount: (i * 11) % 13, + }, + score: float64((i*29)%1000) / 100, + errorRate: float64((i * 5) % 100 / 100), + ttft: float64(30 + (i*3)%500), + hasTTFT: i%3 != 0, + }) + } + return candidates +} + +func selectTopKOpenAICandidatesBySortBenchmark(candidates []openAIAccountCandidateScore, topK int) []openAIAccountCandidateScore { + if len(candidates) == 0 { + return nil + } + if topK <= 0 { + topK = 1 + } + ranked := 
append([]openAIAccountCandidateScore(nil), candidates...) + sort.Slice(ranked, func(i, j int) bool { + return isOpenAIAccountCandidateBetter(ranked[i], ranked[j]) + }) + if topK > len(ranked) { + topK = len(ranked) + } + return ranked[:topK] +} + +func BenchmarkOpenAIAccountSchedulerSelectTopK(b *testing.B) { + cases := []struct { + name string + size int + topK int + }{ + {name: "n_16_k_3", size: 16, topK: 3}, + {name: "n_64_k_3", size: 64, topK: 3}, + {name: "n_256_k_5", size: 256, topK: 5}, + } + + for _, tc := range cases { + candidates := buildOpenAISchedulerBenchmarkCandidates(tc.size) + b.Run(tc.name+"/heap_topk", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + result := selectTopKOpenAICandidates(candidates, tc.topK) + if len(result) == 0 { + b.Fatal("unexpected empty result") + } + } + }) + b.Run(tc.name+"/full_sort", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + result := selectTopKOpenAICandidatesBySortBenchmark(candidates, tc.topK) + if len(result) == 0 { + b.Fatal("unexpected empty result") + } + } + }) + } +} diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go new file mode 100644 index 00000000..7f6f1b66 --- /dev/null +++ b/backend/internal/service/openai_account_scheduler_test.go @@ -0,0 +1,841 @@ +package service + +import ( + "context" + "fmt" + "math" + "sync" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) { + ctx := context.Background() + groupID := int64(9) + account := Account{ + ID: 1001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 2, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + cache := &stubGatewayCache{} + cfg := &config.Config{} + 
cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800 + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + store := svc.getOpenAIWSStateStore() + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_001", account.ID, time.Hour)) + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "resp_prev_001", + "session_hash_001", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, account.ID, selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerPreviousResponse, decision.Layer) + require.True(t, decision.StickyPreviousHit) + require.Equal(t, account.ID, cache.sessionBindings["openai:session_hash_001"]) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testing.T) { + ctx := context.Background() + groupID := int64(10) + account := Account{ + ID: 2001, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + "openai:session_hash_abc": account.ID, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: &config.Config{}, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "session_hash_abc", + "gpt-5.1", + nil, + 
OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, account.ID, selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer) + require.True(t, decision.StickySessionHit) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsSticky(t *testing.T) { + ctx := context.Background() + groupID := int64(10100) + accounts := []Account{ + { + ID: 21001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + { + ID: 21002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 9, + }, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + "openai:session_hash_sticky_busy": 21001, + }, + } + cfg := &config.Config{} + cfg.Gateway.Scheduling.StickySessionMaxWaiting = 2 + cfg.Gateway.Scheduling.StickySessionWaitTimeout = 45 * time.Second + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + concurrencyCache := stubConcurrencyCache{ + acquireResults: map[int64]bool{ + 21001: false, // sticky 账号已满 + 21002: true, // 若回退负载均衡会命中该账号(本测试要求不能切换) + }, + waitCounts: map[int64]int{ + 21001: 999, + }, + loadMap: map[int64]*AccountLoadInfo{ + 21001: {AccountID: 21001, LoadRate: 90, WaitingCount: 9}, + 21002: {AccountID: 21002, LoadRate: 1, WaitingCount: 0}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "session_hash_sticky_busy", + 
"gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(21001), selection.Account.ID, "busy sticky account should remain selected") + require.False(t, selection.Acquired) + require.NotNil(t, selection.WaitPlan) + require.Equal(t, int64(21001), selection.WaitPlan.AccountID) + require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer) + require.True(t, decision.StickySessionHit) +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP(t *testing.T) { + ctx := context.Background() + groupID := int64(1010) + account := Account{ + ID: 2101, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "openai_ws_force_http": true, + }, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + "openai:session_hash_force_http": account.ID, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: &config.Config{}, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "session_hash_force_http", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, account.ID, selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer) + require.True(t, decision.StickySessionHit) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStickyHTTPAccount(t *testing.T) { + ctx := context.Background() + groupID := int64(1011) + accounts := []Account{ + { + ID: 2201, + Platform: PlatformOpenAI, + Type: 
AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + { + ID: 2202, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 5, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + }, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + "openai:session_hash_ws_only": 2201, + }, + } + cfg := newOpenAIWSV2TestConfig() + + // 构造“HTTP-only 账号负载更低”的场景,验证 required transport 会强制过滤。 + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0}, + 2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "session_hash_ws_only", + "gpt-5.1", + nil, + OpenAIUpstreamTransportResponsesWebsocketV2, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(2202), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + require.False(t, decision.StickySessionHit) + require.Equal(t, 1, decision.CandidateCount) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailableAccount(t *testing.T) { + ctx := context.Background() + groupID := int64(1012) + accounts := []Account{ + { + ID: 2301, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{}, + cfg: 
newOpenAIWSV2TestConfig(), + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "", + "gpt-5.1", + nil, + OpenAIUpstreamTransportResponsesWebsocketV2, + ) + require.Error(t, err) + require.Nil(t, selection) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + require.Equal(t, 0, decision.CandidateCount) +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback(t *testing.T) { + ctx := context.Background() + groupID := int64(11) + accounts := []Account{ + { + ID: 3001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + { + ID: 3002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + { + ID: 3003, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 2 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.4 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1 + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8}, + 3002: {AccountID: 3002, LoadRate: 20, WaitingCount: 1}, + 3003: {AccountID: 3003, LoadRate: 10, WaitingCount: 0}, + }, + acquireResults: map[int64]bool{ + 3003: false, // top1 失败,必须回退到 top-K 的下一候选 + 3002: true, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), 
+ } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(3002), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + require.Equal(t, 3, decision.CandidateCount) + require.Equal(t, 2, decision.TopK) + require.Greater(t, decision.LoadSkew, 0.0) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) { + ctx := context.Background() + groupID := int64(12) + account := Account{ + ID: 4001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + "openai:session_hash_metrics": account.ID, + }, + } + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: &config.Config{}, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny) + require.NoError(t, err) + require.NotNil(t, selection) + svc.ReportOpenAIAccountScheduleResult(account.ID, true, intPtrForTest(120)) + svc.RecordOpenAIAccountSwitch() + + snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics() + require.GreaterOrEqual(t, snapshot.SelectTotal, int64(1)) + require.GreaterOrEqual(t, snapshot.StickySessionHitTotal, int64(1)) + require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1)) + require.GreaterOrEqual(t, snapshot.SchedulerLatencyMsAvg, float64(0)) + require.GreaterOrEqual(t, snapshot.StickyHitRatio, 0.0) + require.GreaterOrEqual(t, snapshot.RuntimeStatsAccountCount, 1) +} + +func intPtrForTest(v int) *int 
{ + return &v +} + +func TestOpenAIAccountRuntimeStats_ReportAndSnapshot(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + stats.report(1001, true, nil) + firstTTFT := 100 + stats.report(1001, false, &firstTTFT) + secondTTFT := 200 + stats.report(1001, false, &secondTTFT) + + errorRate, ttft, hasTTFT := stats.snapshot(1001) + require.True(t, hasTTFT) + require.InDelta(t, 0.36, errorRate, 1e-9) + require.InDelta(t, 120.0, ttft, 1e-9) + require.Equal(t, 1, stats.size()) +} + +func TestOpenAIAccountRuntimeStats_ReportConcurrent(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + + const ( + accountCount = 4 + workers = 16 + iterations = 800 + ) + var wg sync.WaitGroup + wg.Add(workers) + for worker := 0; worker < workers; worker++ { + worker := worker + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + accountID := int64(i%accountCount + 1) + success := (i+worker)%3 != 0 + ttft := 80 + (i+worker)%40 + stats.report(accountID, success, &ttft) + } + }() + } + wg.Wait() + + require.Equal(t, accountCount, stats.size()) + for accountID := int64(1); accountID <= accountCount; accountID++ { + errorRate, ttft, hasTTFT := stats.snapshot(accountID) + require.GreaterOrEqual(t, errorRate, 0.0) + require.LessOrEqual(t, errorRate, 1.0) + require.True(t, hasTTFT) + require.Greater(t, ttft, 0.0) + } +} + +func TestSelectTopKOpenAICandidates(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + { + account: &Account{ID: 11, Priority: 2}, + loadInfo: &AccountLoadInfo{LoadRate: 10, WaitingCount: 1}, + score: 10.0, + }, + { + account: &Account{ID: 12, Priority: 1}, + loadInfo: &AccountLoadInfo{LoadRate: 20, WaitingCount: 1}, + score: 9.5, + }, + { + account: &Account{ID: 13, Priority: 1}, + loadInfo: &AccountLoadInfo{LoadRate: 30, WaitingCount: 0}, + score: 10.0, + }, + { + account: &Account{ID: 14, Priority: 0}, + loadInfo: &AccountLoadInfo{LoadRate: 40, WaitingCount: 0}, + score: 8.0, + }, + } + + top2 := 
selectTopKOpenAICandidates(candidates, 2) + require.Len(t, top2, 2) + require.Equal(t, int64(13), top2[0].account.ID) + require.Equal(t, int64(11), top2[1].account.ID) + + topAll := selectTopKOpenAICandidates(candidates, 8) + require.Len(t, topAll, len(candidates)) + require.Equal(t, int64(13), topAll[0].account.ID) + require.Equal(t, int64(11), topAll[1].account.ID) + require.Equal(t, int64(12), topAll[2].account.ID) + require.Equal(t, int64(14), topAll[3].account.ID) +} + +func TestBuildOpenAIWeightedSelectionOrder_DeterministicBySessionSeed(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + { + account: &Account{ID: 101}, + loadInfo: &AccountLoadInfo{LoadRate: 10, WaitingCount: 0}, + score: 4.2, + }, + { + account: &Account{ID: 102}, + loadInfo: &AccountLoadInfo{LoadRate: 30, WaitingCount: 1}, + score: 3.5, + }, + { + account: &Account{ID: 103}, + loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 2}, + score: 2.1, + }, + } + req := OpenAIAccountScheduleRequest{ + GroupID: int64PtrForTest(99), + SessionHash: "session_seed_fixed", + RequestedModel: "gpt-5.1", + } + + first := buildOpenAIWeightedSelectionOrder(candidates, req) + second := buildOpenAIWeightedSelectionOrder(candidates, req) + require.Len(t, first, len(candidates)) + require.Len(t, second, len(candidates)) + for i := range first { + require.Equal(t, first[i].account.ID, second[i].account.ID) + } +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesAcrossSessions(t *testing.T) { + ctx := context.Background() + groupID := int64(15) + accounts := []Account{ + { + ID: 5101, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 3, + Priority: 0, + }, + { + ID: 5102, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 3, + Priority: 0, + }, + { + ID: 5103, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + 
Schedulable: true, + Concurrency: 3, + Priority: 0, + }, + } + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 3 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1 + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1}, + 5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1}, + 5103: {AccountID: 5103, LoadRate: 20, WaitingCount: 1}, + }, + } + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{sessionBindings: map[string]int64{}}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selected := make(map[int64]int, len(accounts)) + for i := 0; i < 60; i++ { + sessionHash := fmt.Sprintf("session_hash_lb_%d", i) + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + sessionHash, + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + selected[selection.Account.ID]++ + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + } + + // 多 session 应该能打散到多个账号,避免“恒定单账号命中”。 + require.GreaterOrEqual(t, len(selected), 2) +} + +func TestDeriveOpenAISelectionSeed_NoAffinityAddsEntropy(t *testing.T) { + req := OpenAIAccountScheduleRequest{ + RequestedModel: "gpt-5.1", + } + seed1 := deriveOpenAISelectionSeed(req) + time.Sleep(1 * time.Millisecond) + seed2 := deriveOpenAISelectionSeed(req) + require.NotZero(t, seed1) + require.NotZero(t, seed2) + require.NotEqual(t, seed1, seed2) +} + +func TestBuildOpenAIWeightedSelectionOrder_HandlesInvalidScores(t *testing.T) 
{ + candidates := []openAIAccountCandidateScore{ + { + account: &Account{ID: 901}, + loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0}, + score: math.NaN(), + }, + { + account: &Account{ID: 902}, + loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0}, + score: math.Inf(1), + }, + { + account: &Account{ID: 903}, + loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0}, + score: -1, + }, + } + req := OpenAIAccountScheduleRequest{ + SessionHash: "seed_invalid_scores", + } + + order := buildOpenAIWeightedSelectionOrder(candidates, req) + require.Len(t, order, len(candidates)) + seen := map[int64]struct{}{} + for _, item := range order { + seen[item.account.ID] = struct{}{} + } + require.Len(t, seen, len(candidates)) +} + +func TestOpenAISelectionRNG_SeedZeroStillWorks(t *testing.T) { + rng := newOpenAISelectionRNG(0) + v1 := rng.nextUint64() + v2 := rng.nextUint64() + require.NotEqual(t, v1, v2) + require.GreaterOrEqual(t, rng.nextFloat64(), 0.0) + require.Less(t, rng.nextFloat64(), 1.0) +} + +func TestOpenAIAccountCandidateHeap_PushPopAndInvalidType(t *testing.T) { + h := openAIAccountCandidateHeap{} + h.Push(openAIAccountCandidateScore{ + account: &Account{ID: 7001}, + loadInfo: &AccountLoadInfo{LoadRate: 0, WaitingCount: 0}, + score: 1.0, + }) + require.Equal(t, 1, h.Len()) + popped, ok := h.Pop().(openAIAccountCandidateScore) + require.True(t, ok) + require.Equal(t, int64(7001), popped.account.ID) + require.Equal(t, 0, h.Len()) + + require.Panics(t, func() { + h.Push("bad_element_type") + }) +} + +func TestClamp01_AllBranches(t *testing.T) { + require.Equal(t, 0.0, clamp01(-0.2)) + require.Equal(t, 1.0, clamp01(1.3)) + require.Equal(t, 0.5, clamp01(0.5)) +} + +func TestCalcLoadSkewByMoments_Branches(t *testing.T) { + require.Equal(t, 0.0, calcLoadSkewByMoments(1, 1, 1)) + // variance < 0 分支:sumSquares/count - mean^2 为负值时应钳制为 0。 + require.Equal(t, 0.0, calcLoadSkewByMoments(1, 0, 2)) + require.GreaterOrEqual(t, calcLoadSkewByMoments(6, 20, 3), 0.0) +} + 
+func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) { + schedulerAny := newDefaultOpenAIAccountScheduler(&OpenAIGatewayService{}, nil) + scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler) + require.True(t, ok) + + ttft := 100 + scheduler.ReportResult(1001, true, &ttft) + scheduler.ReportSwitch() + scheduler.metrics.recordSelect(OpenAIAccountScheduleDecision{ + Layer: openAIAccountScheduleLayerLoadBalance, + LatencyMs: 8, + LoadSkew: 0.5, + StickyPreviousHit: true, + }) + scheduler.metrics.recordSelect(OpenAIAccountScheduleDecision{ + Layer: openAIAccountScheduleLayerSessionSticky, + LatencyMs: 6, + LoadSkew: 0.2, + StickySessionHit: true, + }) + + snapshot := scheduler.SnapshotMetrics() + require.Equal(t, int64(2), snapshot.SelectTotal) + require.Equal(t, int64(1), snapshot.StickyPreviousHitTotal) + require.Equal(t, int64(1), snapshot.StickySessionHitTotal) + require.Equal(t, int64(1), snapshot.LoadBalanceSelectTotal) + require.Equal(t, int64(1), snapshot.AccountSwitchTotal) + require.Greater(t, snapshot.SchedulerLatencyMsAvg, 0.0) + require.Greater(t, snapshot.StickyHitRatio, 0.0) + require.Greater(t, snapshot.LoadSkewAvg, 0.0) +} + +func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) { + svc := &OpenAIGatewayService{} + ttft := 120 + svc.ReportOpenAIAccountScheduleResult(10, true, &ttft) + svc.RecordOpenAIAccountSwitch() + snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics() + require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1)) + require.Equal(t, 7, svc.openAIWSLBTopK()) + require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL()) + + defaultWeights := svc.openAIWSSchedulerWeights() + require.Equal(t, 1.0, defaultWeights.Priority) + require.Equal(t, 1.0, defaultWeights.Load) + require.Equal(t, 0.7, defaultWeights.Queue) + require.Equal(t, 0.8, defaultWeights.ErrorRate) + require.Equal(t, 0.5, defaultWeights.TTFT) + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 9 + 
cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 180 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.2 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 0.3 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.4 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.5 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.6 + svcWithCfg := &OpenAIGatewayService{cfg: cfg} + + require.Equal(t, 9, svcWithCfg.openAIWSLBTopK()) + require.Equal(t, 180*time.Second, svcWithCfg.openAIWSSessionStickyTTL()) + customWeights := svcWithCfg.openAIWSSchedulerWeights() + require.Equal(t, 0.2, customWeights.Priority) + require.Equal(t, 0.3, customWeights.Load) + require.Equal(t, 0.4, customWeights.Queue) + require.Equal(t, 0.5, customWeights.ErrorRate) + require.Equal(t, 0.6, customWeights.TTFT) +} + +func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *testing.T) { + scheduler := &defaultOpenAIAccountScheduler{} + require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportAny)) + require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE)) + require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2)) + + cfg := newOpenAIWSV2TestConfig() + scheduler.service = &OpenAIGatewayService{cfg: cfg} + account := &Account{ + ID: 8801, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + require.True(t, scheduler.isAccountTransportCompatible(account, OpenAIUpstreamTransportResponsesWebsocketV2)) +} + +func int64PtrForTest(v int64) *int64 { + return &v +} diff --git a/backend/internal/service/openai_client_restriction_detector.go b/backend/internal/service/openai_client_restriction_detector.go new file mode 100644 index 00000000..d1784e11 --- /dev/null +++ 
b/backend/internal/service/openai_client_restriction_detector.go @@ -0,0 +1,86 @@ +package service + +import ( + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/gin-gonic/gin" +) + +const ( + // CodexClientRestrictionReasonDisabled 表示账号未开启 codex_cli_only。 + CodexClientRestrictionReasonDisabled = "codex_cli_only_disabled" + // CodexClientRestrictionReasonMatchedUA 表示请求命中官方客户端 UA 白名单。 + CodexClientRestrictionReasonMatchedUA = "official_client_user_agent_matched" + // CodexClientRestrictionReasonMatchedOriginator 表示请求命中官方客户端 originator 白名单。 + CodexClientRestrictionReasonMatchedOriginator = "official_client_originator_matched" + // CodexClientRestrictionReasonNotMatchedUA 表示请求未命中官方客户端 UA 白名单。 + CodexClientRestrictionReasonNotMatchedUA = "official_client_user_agent_not_matched" + // CodexClientRestrictionReasonForceCodexCLI 表示通过 ForceCodexCLI 配置兜底放行。 + CodexClientRestrictionReasonForceCodexCLI = "force_codex_cli_enabled" +) + +// CodexClientRestrictionDetectionResult 是 codex_cli_only 统一检测入口结果。 +type CodexClientRestrictionDetectionResult struct { + Enabled bool + Matched bool + Reason string +} + +// CodexClientRestrictionDetector 定义 codex_cli_only 统一检测入口。 +type CodexClientRestrictionDetector interface { + Detect(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult +} + +// OpenAICodexClientRestrictionDetector 为 OpenAI OAuth codex_cli_only 的默认实现。 +type OpenAICodexClientRestrictionDetector struct { + cfg *config.Config +} + +func NewOpenAICodexClientRestrictionDetector(cfg *config.Config) *OpenAICodexClientRestrictionDetector { + return &OpenAICodexClientRestrictionDetector{cfg: cfg} +} + +func (d *OpenAICodexClientRestrictionDetector) Detect(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult { + if account == nil || !account.IsCodexCLIOnlyEnabled() { + return CodexClientRestrictionDetectionResult{ + Enabled: false, + Matched: false, + Reason: 
CodexClientRestrictionReasonDisabled, + } + } + + if d != nil && d.cfg != nil && d.cfg.Gateway.ForceCodexCLI { + return CodexClientRestrictionDetectionResult{ + Enabled: true, + Matched: true, + Reason: CodexClientRestrictionReasonForceCodexCLI, + } + } + + userAgent := "" + originator := "" + if c != nil { + userAgent = c.GetHeader("User-Agent") + originator = c.GetHeader("originator") + } + if openai.IsCodexOfficialClientRequest(userAgent) { + return CodexClientRestrictionDetectionResult{ + Enabled: true, + Matched: true, + Reason: CodexClientRestrictionReasonMatchedUA, + } + } + if openai.IsCodexOfficialClientOriginator(originator) { + return CodexClientRestrictionDetectionResult{ + Enabled: true, + Matched: true, + Reason: CodexClientRestrictionReasonMatchedOriginator, + } + } + + return CodexClientRestrictionDetectionResult{ + Enabled: true, + Matched: false, + Reason: CodexClientRestrictionReasonNotMatchedUA, + } +} diff --git a/backend/internal/service/openai_client_restriction_detector_test.go b/backend/internal/service/openai_client_restriction_detector_test.go new file mode 100644 index 00000000..984b4ff6 --- /dev/null +++ b/backend/internal/service/openai_client_restriction_detector_test.go @@ -0,0 +1,124 @@ +package service + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func newCodexDetectorTestContext(ua string, originator string) *gin.Context { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + if ua != "" { + c.Request.Header.Set("User-Agent", ua) + } + if originator != "" { + c.Request.Header.Set("originator", originator) + } + return c +} + +func TestOpenAICodexClientRestrictionDetector_Detect(t *testing.T) { + gin.SetMode(gin.TestMode) + + t.Run("未开启开关时绕过", func(t *testing.T) { + detector := 
NewOpenAICodexClientRestrictionDetector(nil) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{}} + + result := detector.Detect(newCodexDetectorTestContext("curl/8.0", ""), account) + require.False(t, result.Enabled) + require.False(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonDisabled, result.Reason) + }) + + t.Run("开启后 codex_cli_rs 命中", func(t *testing.T) { + detector := NewOpenAICodexClientRestrictionDetector(nil) + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{"codex_cli_only": true}, + } + + result := detector.Detect(newCodexDetectorTestContext("codex_cli_rs/0.99.0", ""), account) + require.True(t, result.Enabled) + require.True(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonMatchedUA, result.Reason) + }) + + t.Run("开启后 codex_vscode 命中", func(t *testing.T) { + detector := NewOpenAICodexClientRestrictionDetector(nil) + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{"codex_cli_only": true}, + } + + result := detector.Detect(newCodexDetectorTestContext("codex_vscode/1.0.0", ""), account) + require.True(t, result.Enabled) + require.True(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonMatchedUA, result.Reason) + }) + + t.Run("开启后 codex_app 命中", func(t *testing.T) { + detector := NewOpenAICodexClientRestrictionDetector(nil) + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{"codex_cli_only": true}, + } + + result := detector.Detect(newCodexDetectorTestContext("codex_app/2.1.0", ""), account) + require.True(t, result.Enabled) + require.True(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonMatchedUA, result.Reason) + }) + + t.Run("开启后 originator 命中", func(t *testing.T) { + detector := NewOpenAICodexClientRestrictionDetector(nil) + account := &Account{ + Platform: PlatformOpenAI, + Type: 
AccountTypeOAuth, + Extra: map[string]any{"codex_cli_only": true}, + } + + result := detector.Detect(newCodexDetectorTestContext("curl/8.0", "codex_chatgpt_desktop"), account) + require.True(t, result.Enabled) + require.True(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonMatchedOriginator, result.Reason) + }) + + t.Run("开启后非官方客户端拒绝", func(t *testing.T) { + detector := NewOpenAICodexClientRestrictionDetector(nil) + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{"codex_cli_only": true}, + } + + result := detector.Detect(newCodexDetectorTestContext("curl/8.0", "my_client"), account) + require.True(t, result.Enabled) + require.False(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonNotMatchedUA, result.Reason) + }) + + t.Run("开启 ForceCodexCLI 时允许通过", func(t *testing.T) { + detector := NewOpenAICodexClientRestrictionDetector(&config.Config{ + Gateway: config.GatewayConfig{ForceCodexCLI: true}, + }) + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{"codex_cli_only": true}, + } + + result := detector.Detect(newCodexDetectorTestContext("curl/8.0", "my_client"), account) + require.True(t, result.Enabled) + require.True(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonForceCodexCLI, result.Reason) + }) +} diff --git a/backend/internal/service/openai_client_transport.go b/backend/internal/service/openai_client_transport.go new file mode 100644 index 00000000..c9cf3246 --- /dev/null +++ b/backend/internal/service/openai_client_transport.go @@ -0,0 +1,71 @@ +package service + +import ( + "strings" + + "github.com/gin-gonic/gin" +) + +// OpenAIClientTransport 表示客户端入站协议类型。 +type OpenAIClientTransport string + +const ( + OpenAIClientTransportUnknown OpenAIClientTransport = "" + OpenAIClientTransportHTTP OpenAIClientTransport = "http" + OpenAIClientTransportWS OpenAIClientTransport = "ws" +) + +const 
openAIClientTransportContextKey = "openai_client_transport" + +// SetOpenAIClientTransport 标记当前请求的客户端入站协议。 +func SetOpenAIClientTransport(c *gin.Context, transport OpenAIClientTransport) { + if c == nil { + return + } + normalized := normalizeOpenAIClientTransport(transport) + if normalized == OpenAIClientTransportUnknown { + return + } + c.Set(openAIClientTransportContextKey, string(normalized)) +} + +// GetOpenAIClientTransport 读取当前请求的客户端入站协议。 +func GetOpenAIClientTransport(c *gin.Context) OpenAIClientTransport { + if c == nil { + return OpenAIClientTransportUnknown + } + raw, ok := c.Get(openAIClientTransportContextKey) + if !ok || raw == nil { + return OpenAIClientTransportUnknown + } + + switch v := raw.(type) { + case OpenAIClientTransport: + return normalizeOpenAIClientTransport(v) + case string: + return normalizeOpenAIClientTransport(OpenAIClientTransport(v)) + default: + return OpenAIClientTransportUnknown + } +} + +func normalizeOpenAIClientTransport(transport OpenAIClientTransport) OpenAIClientTransport { + switch strings.ToLower(strings.TrimSpace(string(transport))) { + case string(OpenAIClientTransportHTTP), "http_sse", "sse": + return OpenAIClientTransportHTTP + case string(OpenAIClientTransportWS), "websocket": + return OpenAIClientTransportWS + default: + return OpenAIClientTransportUnknown + } +} + +func resolveOpenAIWSDecisionByClientTransport( + decision OpenAIWSProtocolDecision, + clientTransport OpenAIClientTransport, +) OpenAIWSProtocolDecision { + if clientTransport == OpenAIClientTransportHTTP { + return openAIWSHTTPDecision("client_protocol_http") + } + return decision +} diff --git a/backend/internal/service/openai_client_transport_test.go b/backend/internal/service/openai_client_transport_test.go new file mode 100644 index 00000000..ef90e614 --- /dev/null +++ b/backend/internal/service/openai_client_transport_test.go @@ -0,0 +1,107 @@ +package service + +import ( + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + 
"github.com/stretchr/testify/require" +) + +func TestOpenAIClientTransport_SetAndGet(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + require.Equal(t, OpenAIClientTransportUnknown, GetOpenAIClientTransport(c)) + + SetOpenAIClientTransport(c, OpenAIClientTransportHTTP) + require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c)) + + SetOpenAIClientTransport(c, OpenAIClientTransportWS) + require.Equal(t, OpenAIClientTransportWS, GetOpenAIClientTransport(c)) +} + +func TestOpenAIClientTransport_GetNormalizesRawContextValue(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + rawValue any + want OpenAIClientTransport + }{ + { + name: "type_value_ws", + rawValue: OpenAIClientTransportWS, + want: OpenAIClientTransportWS, + }, + { + name: "http_sse_alias", + rawValue: "http_sse", + want: OpenAIClientTransportHTTP, + }, + { + name: "sse_alias", + rawValue: "sSe", + want: OpenAIClientTransportHTTP, + }, + { + name: "websocket_alias", + rawValue: "WebSocket", + want: OpenAIClientTransportWS, + }, + { + name: "invalid_string", + rawValue: "tcp", + want: OpenAIClientTransportUnknown, + }, + { + name: "invalid_type", + rawValue: 123, + want: OpenAIClientTransportUnknown, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Set(openAIClientTransportContextKey, tt.rawValue) + require.Equal(t, tt.want, GetOpenAIClientTransport(c)) + }) + } +} + +func TestOpenAIClientTransport_NilAndUnknownInput(t *testing.T) { + SetOpenAIClientTransport(nil, OpenAIClientTransportHTTP) + require.Equal(t, OpenAIClientTransportUnknown, GetOpenAIClientTransport(nil)) + + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + SetOpenAIClientTransport(c, OpenAIClientTransportUnknown) + _, exists := c.Get(openAIClientTransportContextKey) + require.False(t, exists) + 
+ SetOpenAIClientTransport(c, OpenAIClientTransport(" ")) + _, exists = c.Get(openAIClientTransportContextKey) + require.False(t, exists) +} + +func TestResolveOpenAIWSDecisionByClientTransport(t *testing.T) { + base := OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportResponsesWebsocketV2, + Reason: "ws_v2_enabled", + } + + httpDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportHTTP) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, httpDecision.Transport) + require.Equal(t, "client_protocol_http", httpDecision.Reason) + + wsDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportWS) + require.Equal(t, base, wsDecision) + + unknownDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportUnknown) + require.Equal(t, base, unknownDecision) +} diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index cea81693..16befb82 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -2,73 +2,66 @@ package service import ( _ "embed" - "encoding/json" - "fmt" - "io" - "net/http" - "os" - "path/filepath" "strings" - "time" -) - -const ( - opencodeCodexHeaderURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex_header.txt" - codexCacheTTL = 15 * time.Minute ) //go:embed prompts/codex_cli_instructions.md var codexCLIInstructions string var codexModelMap = map[string]string{ - "gpt-5.3": "gpt-5.3", - "gpt-5.3-none": "gpt-5.3", - "gpt-5.3-low": "gpt-5.3", - "gpt-5.3-medium": "gpt-5.3", - "gpt-5.3-high": "gpt-5.3", - "gpt-5.3-xhigh": "gpt-5.3", - "gpt-5.3-codex": "gpt-5.3-codex", - "gpt-5.3-codex-low": "gpt-5.3-codex", - "gpt-5.3-codex-medium": "gpt-5.3-codex", - "gpt-5.3-codex-high": "gpt-5.3-codex", - "gpt-5.3-codex-xhigh": "gpt-5.3-codex", - "gpt-5.1-codex": "gpt-5.1-codex", - "gpt-5.1-codex-low": 
"gpt-5.1-codex", - "gpt-5.1-codex-medium": "gpt-5.1-codex", - "gpt-5.1-codex-high": "gpt-5.1-codex", - "gpt-5.1-codex-max": "gpt-5.1-codex-max", - "gpt-5.1-codex-max-low": "gpt-5.1-codex-max", - "gpt-5.1-codex-max-medium": "gpt-5.1-codex-max", - "gpt-5.1-codex-max-high": "gpt-5.1-codex-max", - "gpt-5.1-codex-max-xhigh": "gpt-5.1-codex-max", - "gpt-5.2": "gpt-5.2", - "gpt-5.2-none": "gpt-5.2", - "gpt-5.2-low": "gpt-5.2", - "gpt-5.2-medium": "gpt-5.2", - "gpt-5.2-high": "gpt-5.2", - "gpt-5.2-xhigh": "gpt-5.2", - "gpt-5.2-codex": "gpt-5.2-codex", - "gpt-5.2-codex-low": "gpt-5.2-codex", - "gpt-5.2-codex-medium": "gpt-5.2-codex", - "gpt-5.2-codex-high": "gpt-5.2-codex", - "gpt-5.2-codex-xhigh": "gpt-5.2-codex", - "gpt-5.1-codex-mini": "gpt-5.1-codex-mini", - "gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini", - "gpt-5.1-codex-mini-high": "gpt-5.1-codex-mini", - "gpt-5.1": "gpt-5.1", - "gpt-5.1-none": "gpt-5.1", - "gpt-5.1-low": "gpt-5.1", - "gpt-5.1-medium": "gpt-5.1", - "gpt-5.1-high": "gpt-5.1", - "gpt-5.1-chat-latest": "gpt-5.1", - "gpt-5-codex": "gpt-5.1-codex", - "codex-mini-latest": "gpt-5.1-codex-mini", - "gpt-5-codex-mini": "gpt-5.1-codex-mini", - "gpt-5-codex-mini-medium": "gpt-5.1-codex-mini", - "gpt-5-codex-mini-high": "gpt-5.1-codex-mini", - "gpt-5": "gpt-5.1", - "gpt-5-mini": "gpt-5.1", - "gpt-5-nano": "gpt-5.1", + "gpt-5.3": "gpt-5.3-codex", + "gpt-5.3-none": "gpt-5.3-codex", + "gpt-5.3-low": "gpt-5.3-codex", + "gpt-5.3-medium": "gpt-5.3-codex", + "gpt-5.3-high": "gpt-5.3-codex", + "gpt-5.3-xhigh": "gpt-5.3-codex", + "gpt-5.3-codex": "gpt-5.3-codex", + "gpt-5.3-codex-spark": "gpt-5.3-codex", + "gpt-5.3-codex-spark-low": "gpt-5.3-codex", + "gpt-5.3-codex-spark-medium": "gpt-5.3-codex", + "gpt-5.3-codex-spark-high": "gpt-5.3-codex", + "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", + "gpt-5.3-codex-low": "gpt-5.3-codex", + "gpt-5.3-codex-medium": "gpt-5.3-codex", + "gpt-5.3-codex-high": "gpt-5.3-codex", + "gpt-5.3-codex-xhigh": "gpt-5.3-codex", + 
"gpt-5.1-codex": "gpt-5.1-codex", + "gpt-5.1-codex-low": "gpt-5.1-codex", + "gpt-5.1-codex-medium": "gpt-5.1-codex", + "gpt-5.1-codex-high": "gpt-5.1-codex", + "gpt-5.1-codex-max": "gpt-5.1-codex-max", + "gpt-5.1-codex-max-low": "gpt-5.1-codex-max", + "gpt-5.1-codex-max-medium": "gpt-5.1-codex-max", + "gpt-5.1-codex-max-high": "gpt-5.1-codex-max", + "gpt-5.1-codex-max-xhigh": "gpt-5.1-codex-max", + "gpt-5.2": "gpt-5.2", + "gpt-5.2-none": "gpt-5.2", + "gpt-5.2-low": "gpt-5.2", + "gpt-5.2-medium": "gpt-5.2", + "gpt-5.2-high": "gpt-5.2", + "gpt-5.2-xhigh": "gpt-5.2", + "gpt-5.2-codex": "gpt-5.2-codex", + "gpt-5.2-codex-low": "gpt-5.2-codex", + "gpt-5.2-codex-medium": "gpt-5.2-codex", + "gpt-5.2-codex-high": "gpt-5.2-codex", + "gpt-5.2-codex-xhigh": "gpt-5.2-codex", + "gpt-5.1-codex-mini": "gpt-5.1-codex-mini", + "gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini", + "gpt-5.1-codex-mini-high": "gpt-5.1-codex-mini", + "gpt-5.1": "gpt-5.1", + "gpt-5.1-none": "gpt-5.1", + "gpt-5.1-low": "gpt-5.1", + "gpt-5.1-medium": "gpt-5.1", + "gpt-5.1-high": "gpt-5.1", + "gpt-5.1-chat-latest": "gpt-5.1", + "gpt-5-codex": "gpt-5.1-codex", + "codex-mini-latest": "gpt-5.1-codex-mini", + "gpt-5-codex-mini": "gpt-5.1-codex-mini", + "gpt-5-codex-mini-medium": "gpt-5.1-codex-mini", + "gpt-5-codex-mini-high": "gpt-5.1-codex-mini", + "gpt-5": "gpt-5.1", + "gpt-5-mini": "gpt-5.1", + "gpt-5-nano": "gpt-5.1", } type codexTransformResult struct { @@ -77,12 +70,6 @@ type codexTransformResult struct { PromptCacheKey string } -type opencodeCacheMetadata struct { - ETag string `json:"etag"` - LastFetch string `json:"lastFetch,omitempty"` - LastChecked int64 `json:"lastChecked"` -} - func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTransformResult { result := codexTransformResult{} // 工具续链需求会影响存储策略与 input 过滤逻辑。 @@ -112,13 +99,19 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTran result.Modified = true } - if _, ok := 
reqBody["max_output_tokens"]; ok { - delete(reqBody, "max_output_tokens") - result.Modified = true - } - if _, ok := reqBody["max_completion_tokens"]; ok { - delete(reqBody, "max_completion_tokens") - result.Modified = true + // Strip parameters unsupported by codex models via the Responses API. + for _, key := range []string{ + "max_output_tokens", + "max_completion_tokens", + "temperature", + "top_p", + "frequency_penalty", + "presence_penalty", + } { + if _, ok := reqBody[key]; ok { + delete(reqBody, key) + result.Modified = true + } } if normalizeCodexTools(reqBody) { @@ -171,7 +164,7 @@ func normalizeCodexModel(model string) string { return "gpt-5.3-codex" } if strings.Contains(normalized, "gpt-5.3") || strings.Contains(normalized, "gpt 5.3") { - return "gpt-5.3" + return "gpt-5.3-codex" } if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") { return "gpt-5.1-codex-max" @@ -216,54 +209,9 @@ func getNormalizedCodexModel(modelID string) string { return "" } -func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string { - cacheDir := codexCachePath("") - if cacheDir == "" { - return "" - } - cacheFile := filepath.Join(cacheDir, cacheFileName) - metaFile := filepath.Join(cacheDir, metaFileName) - - var cachedContent string - if content, ok := readFile(cacheFile); ok { - cachedContent = content - } - - var meta opencodeCacheMetadata - if loadJSON(metaFile, &meta) && meta.LastChecked > 0 && cachedContent != "" { - if time.Since(time.UnixMilli(meta.LastChecked)) < codexCacheTTL { - return cachedContent - } - } - - content, etag, status, err := fetchWithETag(url, meta.ETag) - if err == nil && status == http.StatusNotModified && cachedContent != "" { - return cachedContent - } - if err == nil && status >= 200 && status < 300 && content != "" { - _ = writeFile(cacheFile, content) - meta = opencodeCacheMetadata{ - ETag: etag, - LastFetch: time.Now().UTC().Format(time.RFC3339), - LastChecked: 
time.Now().UnixMilli(), - } - _ = writeJSON(metaFile, meta) - return content - } - - return cachedContent -} - func getOpenCodeCodexHeader() string { - // 优先从 opencode 仓库缓存获取指令。 - opencodeInstructions := getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json") - - // 若 opencode 指令可用,直接返回。 - if opencodeInstructions != "" { - return opencodeInstructions - } - - // 否则回退使用本地 Codex CLI 指令。 + // 兼容保留:历史上这里会从 opencode 仓库拉取 codex_header.txt。 + // 现在我们与 Codex CLI 一致,直接使用仓库内置的 instructions,避免读写缓存与外网依赖。 return getCodexCLIInstructions() } @@ -281,8 +229,8 @@ func GetCodexCLIInstructions() string { } // applyInstructions 处理 instructions 字段 -// isCodexCLI=true: 仅补充缺失的 instructions(使用 opencode 指令) -// isCodexCLI=false: 优先使用 opencode 指令覆盖 +// isCodexCLI=true: 仅补充缺失的 instructions(使用内置 Codex CLI 指令) +// isCodexCLI=false: 优先使用内置 Codex CLI 指令覆盖 func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool { if isCodexCLI { return applyCodexCLIInstructions(reqBody) @@ -291,13 +239,13 @@ func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool { } // applyCodexCLIInstructions 为 Codex CLI 请求补充缺失的 instructions -// 仅在 instructions 为空时添加 opencode 指令 +// 仅在 instructions 为空时添加内置 Codex CLI 指令(不依赖 opencode 缓存/回源) func applyCodexCLIInstructions(reqBody map[string]any) bool { if !isInstructionsEmpty(reqBody) { return false // 已有有效 instructions,不修改 } - instructions := strings.TrimSpace(getOpenCodeCodexHeader()) + instructions := strings.TrimSpace(getCodexCLIInstructions()) if instructions != "" { reqBody["instructions"] = instructions return true @@ -306,8 +254,8 @@ func applyCodexCLIInstructions(reqBody map[string]any) bool { return false } -// applyOpenCodeInstructions 为非 Codex CLI 请求应用 opencode 指令 -// 优先使用 opencode 指令覆盖 +// applyOpenCodeInstructions 为非 Codex CLI 请求应用内置 Codex CLI 指令(兼容历史函数名) +// 优先使用内置 Codex CLI 指令覆盖 func applyOpenCodeInstructions(reqBody map[string]any) bool { instructions := 
strings.TrimSpace(getOpenCodeCodexHeader()) existingInstructions, _ := reqBody["instructions"].(string) @@ -489,85 +437,3 @@ func normalizeCodexTools(reqBody map[string]any) bool { return modified } - -func codexCachePath(filename string) string { - home, err := os.UserHomeDir() - if err != nil { - return "" - } - cacheDir := filepath.Join(home, ".opencode", "cache") - if filename == "" { - return cacheDir - } - return filepath.Join(cacheDir, filename) -} - -func readFile(path string) (string, bool) { - if path == "" { - return "", false - } - data, err := os.ReadFile(path) - if err != nil { - return "", false - } - return string(data), true -} - -func writeFile(path, content string) error { - if path == "" { - return fmt.Errorf("empty cache path") - } - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - return err - } - return os.WriteFile(path, []byte(content), 0o644) -} - -func loadJSON(path string, target any) bool { - data, err := os.ReadFile(path) - if err != nil { - return false - } - if err := json.Unmarshal(data, target); err != nil { - return false - } - return true -} - -func writeJSON(path string, value any) error { - if path == "" { - return fmt.Errorf("empty json path") - } - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - return err - } - data, err := json.Marshal(value) - if err != nil { - return err - } - return os.WriteFile(path, data, 0o644) -} - -func fetchWithETag(url, etag string) (string, string, int, error) { - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return "", "", 0, err - } - req.Header.Set("User-Agent", "sub2api-codex") - if etag != "" { - req.Header.Set("If-None-Match", etag) - } - resp, err := http.DefaultClient.Do(req) - if err != nil { - return "", "", 0, err - } - defer func() { - _ = resp.Body.Close() - }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", "", resp.StatusCode, err - } - return string(body), resp.Header.Get("etag"), resp.StatusCode, 
nil -} diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index cc0acafc..27093f6c 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -1,18 +1,13 @@ package service import ( - "encoding/json" - "os" - "path/filepath" "testing" - "time" "github.com/stretchr/testify/require" ) func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) { // 续链场景:保留 item_reference 与 id,但不再强制 store=true。 - setupCodexCache(t) reqBody := map[string]any{ "model": "gpt-5.2", @@ -48,7 +43,6 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) { func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) { // 续链场景:显式 store=false 不再强制为 true,保持 false。 - setupCodexCache(t) reqBody := map[string]any{ "model": "gpt-5.1", @@ -68,7 +62,6 @@ func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) { func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) { // 显式 store=true 也会强制为 false。 - setupCodexCache(t) reqBody := map[string]any{ "model": "gpt-5.1", @@ -88,7 +81,6 @@ func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) { func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(t *testing.T) { // 非续链场景:未设置 store 时默认 false,并移除 input 中的 id。 - setupCodexCache(t) reqBody := map[string]any{ "model": "gpt-5.1", @@ -130,8 +122,6 @@ func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) { } func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunctionTools(t *testing.T) { - setupCodexCache(t) - reqBody := map[string]any{ "model": "gpt-5.1", "tools": []any{ @@ -162,7 +152,6 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { // 空 input 应保持为空且不触发异常。 - setupCodexCache(t) 
reqBody := map[string]any{ "model": "gpt-5.1", @@ -178,97 +167,39 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { func TestNormalizeCodexModel_Gpt53(t *testing.T) { cases := map[string]string{ - "gpt-5.3": "gpt-5.3", - "gpt-5.3-codex": "gpt-5.3-codex", - "gpt-5.3-codex-xhigh": "gpt-5.3-codex", - "gpt 5.3 codex": "gpt-5.3-codex", + "gpt-5.3": "gpt-5.3-codex", + "gpt-5.3-codex": "gpt-5.3-codex", + "gpt-5.3-codex-xhigh": "gpt-5.3-codex", + "gpt-5.3-codex-spark": "gpt-5.3-codex", + "gpt-5.3-codex-spark-high": "gpt-5.3-codex", + "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", + "gpt 5.3 codex": "gpt-5.3-codex", } for input, expected := range cases { require.Equal(t, expected, normalizeCodexModel(input)) } - } func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) { - // Codex CLI 场景:已有 instructions 时保持不变 - setupCodexCache(t) + // Codex CLI 场景:已有 instructions 时不修改 reqBody := map[string]any{ "model": "gpt-5.1", - "instructions": "user custom instructions", - "input": []any{}, + "instructions": "existing instructions", } - result := applyCodexOAuthTransform(reqBody, true) + result := applyCodexOAuthTransform(reqBody, true) // isCodexCLI=true instructions, ok := reqBody["instructions"].(string) require.True(t, ok) - require.Equal(t, "user custom instructions", instructions) - // instructions 未变,但其他字段(如 store、stream)可能被修改 - require.True(t, result.Modified) -} - -func TestApplyCodexOAuthTransform_CodexCLI_AddsInstructionsWhenEmpty(t *testing.T) { - // Codex CLI 场景:无 instructions 时补充内置指令 - setupCodexCache(t) - - reqBody := map[string]any{ - "model": "gpt-5.1", - "input": []any{}, - } - - result := applyCodexOAuthTransform(reqBody, true) - - instructions, ok := reqBody["instructions"].(string) - require.True(t, ok) - require.NotEmpty(t, instructions) - require.True(t, result.Modified) -} - -func TestApplyCodexOAuthTransform_NonCodexCLI_UsesOpenCodeInstructions(t *testing.T) { - // 非 Codex CLI 场景:使用 opencode 指令(缓存中有 header) - 
setupCodexCache(t) - - reqBody := map[string]any{ - "model": "gpt-5.1", - "input": []any{}, - } - - result := applyCodexOAuthTransform(reqBody, false) - - instructions, ok := reqBody["instructions"].(string) - require.True(t, ok) - require.Equal(t, "header", instructions) // setupCodexCache 设置的缓存内容 - require.True(t, result.Modified) -} - -func setupCodexCache(t *testing.T) { - t.Helper() - - // 使用临时 HOME 避免触发网络拉取 header。 - // Windows 使用 USERPROFILE,Unix 使用 HOME。 - tempDir := t.TempDir() - t.Setenv("HOME", tempDir) - t.Setenv("USERPROFILE", tempDir) - - cacheDir := filepath.Join(tempDir, ".opencode", "cache") - require.NoError(t, os.MkdirAll(cacheDir, 0o755)) - require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header.txt"), []byte("header"), 0o644)) - - meta := map[string]any{ - "etag": "", - "lastFetch": time.Now().UTC().Format(time.RFC3339), - "lastChecked": time.Now().UnixMilli(), - } - data, err := json.Marshal(meta) - require.NoError(t, err) - require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644)) + require.Equal(t, "existing instructions", instructions) + // Modified 仍可能为 true(因为其他字段被修改),但 instructions 应保持不变 + _ = result } func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T) { // Codex CLI 场景:无 instructions 时补充默认值 - setupCodexCache(t) reqBody := map[string]any{ "model": "gpt-5.1", @@ -284,8 +215,7 @@ func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T } func TestApplyCodexOAuthTransform_NonCodexCLI_OverridesInstructions(t *testing.T) { - // 非 Codex CLI 场景:使用 opencode 指令覆盖 - setupCodexCache(t) + // 非 Codex CLI 场景:使用内置 Codex CLI 指令覆盖 reqBody := map[string]any{ "model": "gpt-5.1", diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 6c4fe256..8606708f 100644 --- a/backend/internal/service/openai_gateway_service.go +++ 
b/backend/internal/service/openai_gateway_service.go @@ -10,20 +10,24 @@ import ( "errors" "fmt" "io" - "log" + "math/rand" "net/http" - "regexp" "sort" "strconv" "strings" + "sync" "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.uber.org/zap" ) const ( @@ -32,20 +36,62 @@ const ( // OpenAI Platform API for API Key accounts (fallback) openaiPlatformAPIURL = "https://api.openai.com/v1/responses" openaiStickySessionTTL = time.Hour // 粘性会话TTL + codexCLIUserAgent = "codex_cli_rs/0.104.0" + // codex_cli_only 拒绝时单个请求头日志长度上限(字符) + codexCLIOnlyHeaderValueMaxBytes = 256 + + // OpenAIParsedRequestBodyKey 缓存 handler 侧已解析的请求体,避免重复解析。 + OpenAIParsedRequestBodyKey = "openai_parsed_request_body" + // OpenAI WS Mode 失败后的重连次数上限(不含首次尝试)。 + // 与 Codex 客户端保持一致:失败后最多重连 5 次。 + openAIWSReconnectRetryLimit = 5 + // OpenAI WS Mode 重连退避默认值(可由配置覆盖)。 + openAIWSRetryBackoffInitialDefault = 120 * time.Millisecond + openAIWSRetryBackoffMaxDefault = 2 * time.Second + openAIWSRetryJitterRatioDefault = 0.2 ) -// openaiSSEDataRe matches SSE data lines with optional whitespace after colon. -// Some upstream APIs return non-standard "data:" without space (should be "data: "). -var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`) - -// OpenAI allowed headers whitelist (for non-OAuth accounts) +// OpenAI allowed headers whitelist (for non-passthrough). 
var openaiAllowedHeaders = map[string]bool{ - "accept-language": true, - "content-type": true, - "conversation_id": true, - "user-agent": true, - "originator": true, - "session_id": true, + "accept-language": true, + "content-type": true, + "conversation_id": true, + "user-agent": true, + "originator": true, + "session_id": true, + "x-codex-turn-state": true, + "x-codex-turn-metadata": true, +} + +// OpenAI passthrough allowed headers whitelist. +// 透传模式下仅放行这些低风险请求头,避免将非标准/环境噪声头传给上游触发风控。 +var openaiPassthroughAllowedHeaders = map[string]bool{ + "accept": true, + "accept-language": true, + "content-type": true, + "conversation_id": true, + "openai-beta": true, + "user-agent": true, + "originator": true, + "session_id": true, + "x-codex-turn-state": true, + "x-codex-turn-metadata": true, +} + +// codex_cli_only 拒绝时记录的请求头白名单(仅用于诊断日志,不参与上游透传) +var codexCLIOnlyDebugHeaderWhitelist = []string{ + "User-Agent", + "Content-Type", + "Accept", + "Accept-Language", + "OpenAI-Beta", + "Originator", + "Session_ID", + "Conversation_ID", + "X-Request-ID", + "X-Client-Request-ID", + "X-Forwarded-For", + "X-Real-IP", } // OpenAICodexUsageSnapshot represents Codex API usage limits from response headers @@ -163,10 +209,40 @@ type OpenAIForwardResult struct { // Stored for usage records display; nil means not provided / not applicable. 
ReasoningEffort *string Stream bool + OpenAIWSMode bool Duration time.Duration FirstTokenMs *int } +type OpenAIWSRetryMetricsSnapshot struct { + RetryAttemptsTotal int64 `json:"retry_attempts_total"` + RetryBackoffMsTotal int64 `json:"retry_backoff_ms_total"` + RetryExhaustedTotal int64 `json:"retry_exhausted_total"` + NonRetryableFastFallbackTotal int64 `json:"non_retryable_fast_fallback_total"` +} + +type OpenAICompatibilityFallbackMetricsSnapshot struct { + SessionHashLegacyReadFallbackTotal int64 `json:"session_hash_legacy_read_fallback_total"` + SessionHashLegacyReadFallbackHit int64 `json:"session_hash_legacy_read_fallback_hit"` + SessionHashLegacyDualWriteTotal int64 `json:"session_hash_legacy_dual_write_total"` + SessionHashLegacyReadHitRate float64 `json:"session_hash_legacy_read_hit_rate"` + + MetadataLegacyFallbackIsMaxTokensOneHaikuTotal int64 `json:"metadata_legacy_fallback_is_max_tokens_one_haiku_total"` + MetadataLegacyFallbackThinkingEnabledTotal int64 `json:"metadata_legacy_fallback_thinking_enabled_total"` + MetadataLegacyFallbackPrefetchedStickyAccount int64 `json:"metadata_legacy_fallback_prefetched_sticky_account_total"` + MetadataLegacyFallbackPrefetchedStickyGroup int64 `json:"metadata_legacy_fallback_prefetched_sticky_group_total"` + MetadataLegacyFallbackSingleAccountRetryTotal int64 `json:"metadata_legacy_fallback_single_account_retry_total"` + MetadataLegacyFallbackAccountSwitchCountTotal int64 `json:"metadata_legacy_fallback_account_switch_count_total"` + MetadataLegacyFallbackTotal int64 `json:"metadata_legacy_fallback_total"` +} + +type openAIWSRetryMetrics struct { + retryAttempts atomic.Int64 + retryBackoffMs atomic.Int64 + retryExhausted atomic.Int64 + nonRetryableFastFallback atomic.Int64 +} + // OpenAIGatewayService handles OpenAI API gateway operations type OpenAIGatewayService struct { accountRepo AccountRepository @@ -175,6 +251,7 @@ type OpenAIGatewayService struct { userSubRepo UserSubscriptionRepository cache GatewayCache 
cfg *config.Config + codexDetector CodexClientRestrictionDetector schedulerSnapshot *SchedulerSnapshotService concurrencyService *ConcurrencyService billingService *BillingService @@ -184,6 +261,19 @@ type OpenAIGatewayService struct { deferredService *DeferredService openAITokenProvider *OpenAITokenProvider toolCorrector *CodexToolCorrector + openaiWSResolver OpenAIWSProtocolResolver + + openaiWSPoolOnce sync.Once + openaiWSStateStoreOnce sync.Once + openaiSchedulerOnce sync.Once + openaiWSPool *openAIWSConnPool + openaiWSStateStore OpenAIWSStateStore + openaiScheduler OpenAIAccountScheduler + openaiAccountStats *openAIAccountRuntimeStats + + openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time + openaiWSRetryMetrics openAIWSRetryMetrics + responseHeaderFilter *responseheaders.CompiledHeaderFilter } // NewOpenAIGatewayService creates a new OpenAIGatewayService @@ -203,23 +293,587 @@ func NewOpenAIGatewayService( deferredService *DeferredService, openAITokenProvider *OpenAITokenProvider, ) *OpenAIGatewayService { - return &OpenAIGatewayService{ - accountRepo: accountRepo, - usageLogRepo: usageLogRepo, - userRepo: userRepo, - userSubRepo: userSubRepo, - cache: cache, - cfg: cfg, - schedulerSnapshot: schedulerSnapshot, - concurrencyService: concurrencyService, - billingService: billingService, - rateLimitService: rateLimitService, - billingCacheService: billingCacheService, - httpUpstream: httpUpstream, - deferredService: deferredService, - openAITokenProvider: openAITokenProvider, - toolCorrector: NewCodexToolCorrector(), + svc := &OpenAIGatewayService{ + accountRepo: accountRepo, + usageLogRepo: usageLogRepo, + userRepo: userRepo, + userSubRepo: userSubRepo, + cache: cache, + cfg: cfg, + codexDetector: NewOpenAICodexClientRestrictionDetector(cfg), + schedulerSnapshot: schedulerSnapshot, + concurrencyService: concurrencyService, + billingService: billingService, + rateLimitService: rateLimitService, + billingCacheService: billingCacheService, + 
httpUpstream: httpUpstream, + deferredService: deferredService, + openAITokenProvider: openAITokenProvider, + toolCorrector: NewCodexToolCorrector(), + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + responseHeaderFilter: compileResponseHeaderFilter(cfg), } + svc.logOpenAIWSModeBootstrap() + return svc +} + +// CloseOpenAIWSPool 关闭 OpenAI WebSocket 连接池的后台 worker 和空闲连接。 +// 应在应用优雅关闭时调用。 +func (s *OpenAIGatewayService) CloseOpenAIWSPool() { + if s != nil && s.openaiWSPool != nil { + s.openaiWSPool.Close() + } +} + +func (s *OpenAIGatewayService) logOpenAIWSModeBootstrap() { + if s == nil || s.cfg == nil { + return + } + wsCfg := s.cfg.Gateway.OpenAIWS + logOpenAIWSModeInfo( + "bootstrap enabled=%v oauth_enabled=%v apikey_enabled=%v force_http=%v responses_websockets_v2=%v responses_websockets=%v payload_log_sample_rate=%.3f event_flush_batch_size=%d event_flush_interval_ms=%d prewarm_cooldown_ms=%d retry_backoff_initial_ms=%d retry_backoff_max_ms=%d retry_jitter_ratio=%.3f retry_total_budget_ms=%d ws_read_limit_bytes=%d", + wsCfg.Enabled, + wsCfg.OAuthEnabled, + wsCfg.APIKeyEnabled, + wsCfg.ForceHTTP, + wsCfg.ResponsesWebsocketsV2, + wsCfg.ResponsesWebsockets, + wsCfg.PayloadLogSampleRate, + wsCfg.EventFlushBatchSize, + wsCfg.EventFlushIntervalMS, + wsCfg.PrewarmCooldownMS, + wsCfg.RetryBackoffInitialMS, + wsCfg.RetryBackoffMaxMS, + wsCfg.RetryJitterRatio, + wsCfg.RetryTotalBudgetMS, + openAIWSMessageReadLimitBytes, + ) +} + +func (s *OpenAIGatewayService) getCodexClientRestrictionDetector() CodexClientRestrictionDetector { + if s != nil && s.codexDetector != nil { + return s.codexDetector + } + var cfg *config.Config + if s != nil { + cfg = s.cfg + } + return NewOpenAICodexClientRestrictionDetector(cfg) +} + +func (s *OpenAIGatewayService) getOpenAIWSProtocolResolver() OpenAIWSProtocolResolver { + if s != nil && s.openaiWSResolver != nil { + return s.openaiWSResolver + } + var cfg *config.Config + if s != nil { + cfg = s.cfg + } + return 
NewOpenAIWSProtocolResolver(cfg) +} + +func classifyOpenAIWSReconnectReason(err error) (string, bool) { + if err == nil { + return "", false + } + var fallbackErr *openAIWSFallbackError + if !errors.As(err, &fallbackErr) || fallbackErr == nil { + return "", false + } + reason := strings.TrimSpace(fallbackErr.Reason) + if reason == "" { + return "", false + } + + baseReason := strings.TrimPrefix(reason, "prewarm_") + + switch baseReason { + case "policy_violation", + "message_too_big", + "upgrade_required", + "ws_unsupported", + "auth_failed", + "previous_response_not_found": + return reason, false + } + + switch baseReason { + case "read_event", + "write_request", + "write", + "acquire_timeout", + "acquire_conn", + "conn_queue_full", + "dial_failed", + "upstream_5xx", + "event_error", + "error_event", + "upstream_error_event", + "ws_connection_limit_reached", + "missing_final_response": + return reason, true + default: + return reason, false + } +} + +func resolveOpenAIWSFallbackErrorResponse(err error) (statusCode int, errType string, clientMessage string, upstreamMessage string, ok bool) { + if err == nil { + return 0, "", "", "", false + } + var fallbackErr *openAIWSFallbackError + if !errors.As(err, &fallbackErr) || fallbackErr == nil { + return 0, "", "", "", false + } + + reason := strings.TrimSpace(fallbackErr.Reason) + reason = strings.TrimPrefix(reason, "prewarm_") + if reason == "" { + return 0, "", "", "", false + } + + var dialErr *openAIWSDialError + if fallbackErr.Err != nil && errors.As(fallbackErr.Err, &dialErr) && dialErr != nil { + if dialErr.StatusCode > 0 { + statusCode = dialErr.StatusCode + } + if dialErr.Err != nil { + upstreamMessage = sanitizeUpstreamErrorMessage(strings.TrimSpace(dialErr.Err.Error())) + } + } + + switch reason { + case "previous_response_not_found": + if statusCode == 0 { + statusCode = http.StatusBadRequest + } + errType = "invalid_request_error" + if upstreamMessage == "" { + upstreamMessage = "previous response not 
found" + } + case "upgrade_required": + if statusCode == 0 { + statusCode = http.StatusUpgradeRequired + } + case "ws_unsupported": + if statusCode == 0 { + statusCode = http.StatusBadRequest + } + case "auth_failed": + if statusCode == 0 { + statusCode = http.StatusUnauthorized + } + case "upstream_rate_limited": + if statusCode == 0 { + statusCode = http.StatusTooManyRequests + } + default: + if statusCode == 0 { + return 0, "", "", "", false + } + } + + if upstreamMessage == "" && fallbackErr.Err != nil { + upstreamMessage = sanitizeUpstreamErrorMessage(strings.TrimSpace(fallbackErr.Err.Error())) + } + if upstreamMessage == "" { + switch reason { + case "upgrade_required": + upstreamMessage = "upstream websocket upgrade required" + case "ws_unsupported": + upstreamMessage = "upstream websocket not supported" + case "auth_failed": + upstreamMessage = "upstream authentication failed" + case "upstream_rate_limited": + upstreamMessage = "upstream rate limit exceeded, please retry later" + default: + upstreamMessage = "Upstream request failed" + } + } + + if errType == "" { + if statusCode == http.StatusTooManyRequests { + errType = "rate_limit_error" + } else { + errType = "upstream_error" + } + } + clientMessage = upstreamMessage + return statusCode, errType, clientMessage, upstreamMessage, true +} + +func (s *OpenAIGatewayService) writeOpenAIWSFallbackErrorResponse(c *gin.Context, account *Account, wsErr error) bool { + if c == nil || c.Writer == nil || c.Writer.Written() { + return false + } + statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse(wsErr) + if !ok { + return false + } + if strings.TrimSpace(clientMessage) == "" { + clientMessage = "Upstream request failed" + } + if strings.TrimSpace(upstreamMessage) == "" { + upstreamMessage = clientMessage + } + + setOpsUpstreamError(c, statusCode, upstreamMessage, "") + if account != nil { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + 
AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: statusCode, + Kind: "ws_error", + Message: upstreamMessage, + }) + } + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "type": errType, + "message": clientMessage, + }, + }) + return true +} + +func (s *OpenAIGatewayService) openAIWSRetryBackoff(attempt int) time.Duration { + if attempt <= 0 { + return 0 + } + + initial := openAIWSRetryBackoffInitialDefault + maxBackoff := openAIWSRetryBackoffMaxDefault + jitterRatio := openAIWSRetryJitterRatioDefault + if s != nil && s.cfg != nil { + wsCfg := s.cfg.Gateway.OpenAIWS + if wsCfg.RetryBackoffInitialMS > 0 { + initial = time.Duration(wsCfg.RetryBackoffInitialMS) * time.Millisecond + } + if wsCfg.RetryBackoffMaxMS > 0 { + maxBackoff = time.Duration(wsCfg.RetryBackoffMaxMS) * time.Millisecond + } + if wsCfg.RetryJitterRatio >= 0 { + jitterRatio = wsCfg.RetryJitterRatio + } + } + if initial <= 0 { + return 0 + } + if maxBackoff <= 0 { + maxBackoff = initial + } + if maxBackoff < initial { + maxBackoff = initial + } + if jitterRatio < 0 { + jitterRatio = 0 + } + if jitterRatio > 1 { + jitterRatio = 1 + } + + shift := attempt - 1 + if shift < 0 { + shift = 0 + } + backoff := initial + if shift > 0 { + backoff = initial * time.Duration(1< maxBackoff { + backoff = maxBackoff + } + if jitterRatio <= 0 { + return backoff + } + jitter := time.Duration(float64(backoff) * jitterRatio) + if jitter <= 0 { + return backoff + } + delta := time.Duration(rand.Int63n(int64(jitter)*2+1)) - jitter + withJitter := backoff + delta + if withJitter < 0 { + return 0 + } + return withJitter +} + +func (s *OpenAIGatewayService) openAIWSRetryTotalBudget() time.Duration { + if s != nil && s.cfg != nil { + ms := s.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS + if ms <= 0 { + return 0 + } + return time.Duration(ms) * time.Millisecond + } + return 0 +} + +func (s *OpenAIGatewayService) recordOpenAIWSRetryAttempt(backoff time.Duration) { + if s == nil { + return + } + 
s.openaiWSRetryMetrics.retryAttempts.Add(1) + if backoff > 0 { + s.openaiWSRetryMetrics.retryBackoffMs.Add(backoff.Milliseconds()) + } +} + +func (s *OpenAIGatewayService) recordOpenAIWSRetryExhausted() { + if s == nil { + return + } + s.openaiWSRetryMetrics.retryExhausted.Add(1) +} + +func (s *OpenAIGatewayService) recordOpenAIWSNonRetryableFastFallback() { + if s == nil { + return + } + s.openaiWSRetryMetrics.nonRetryableFastFallback.Add(1) +} + +func (s *OpenAIGatewayService) SnapshotOpenAIWSRetryMetrics() OpenAIWSRetryMetricsSnapshot { + if s == nil { + return OpenAIWSRetryMetricsSnapshot{} + } + return OpenAIWSRetryMetricsSnapshot{ + RetryAttemptsTotal: s.openaiWSRetryMetrics.retryAttempts.Load(), + RetryBackoffMsTotal: s.openaiWSRetryMetrics.retryBackoffMs.Load(), + RetryExhaustedTotal: s.openaiWSRetryMetrics.retryExhausted.Load(), + NonRetryableFastFallbackTotal: s.openaiWSRetryMetrics.nonRetryableFastFallback.Load(), + } +} + +func SnapshotOpenAICompatibilityFallbackMetrics() OpenAICompatibilityFallbackMetricsSnapshot { + legacyReadFallbackTotal, legacyReadFallbackHit, legacyDualWriteTotal := openAIStickyCompatStats() + isMaxTokensOneHaiku, thinkingEnabled, prefetchedStickyAccount, prefetchedStickyGroup, singleAccountRetry, accountSwitchCount := RequestMetadataFallbackStats() + + readHitRate := float64(0) + if legacyReadFallbackTotal > 0 { + readHitRate = float64(legacyReadFallbackHit) / float64(legacyReadFallbackTotal) + } + metadataFallbackTotal := isMaxTokensOneHaiku + thinkingEnabled + prefetchedStickyAccount + prefetchedStickyGroup + singleAccountRetry + accountSwitchCount + + return OpenAICompatibilityFallbackMetricsSnapshot{ + SessionHashLegacyReadFallbackTotal: legacyReadFallbackTotal, + SessionHashLegacyReadFallbackHit: legacyReadFallbackHit, + SessionHashLegacyDualWriteTotal: legacyDualWriteTotal, + SessionHashLegacyReadHitRate: readHitRate, + + MetadataLegacyFallbackIsMaxTokensOneHaikuTotal: isMaxTokensOneHaiku, + 
MetadataLegacyFallbackThinkingEnabledTotal: thinkingEnabled, + MetadataLegacyFallbackPrefetchedStickyAccount: prefetchedStickyAccount, + MetadataLegacyFallbackPrefetchedStickyGroup: prefetchedStickyGroup, + MetadataLegacyFallbackSingleAccountRetryTotal: singleAccountRetry, + MetadataLegacyFallbackAccountSwitchCountTotal: accountSwitchCount, + MetadataLegacyFallbackTotal: metadataFallbackTotal, + } +} + +func (s *OpenAIGatewayService) detectCodexClientRestriction(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult { + return s.getCodexClientRestrictionDetector().Detect(c, account) +} + +func getAPIKeyIDFromContext(c *gin.Context) int64 { + if c == nil { + return 0 + } + v, exists := c.Get("api_key") + if !exists { + return 0 + } + apiKey, ok := v.(*APIKey) + if !ok || apiKey == nil { + return 0 + } + return apiKey.ID +} + +func logCodexCLIOnlyDetection(ctx context.Context, c *gin.Context, account *Account, apiKeyID int64, result CodexClientRestrictionDetectionResult, body []byte) { + if !result.Enabled { + return + } + if ctx == nil { + ctx = context.Background() + } + accountID := int64(0) + if account != nil { + accountID = account.ID + } + fields := []zap.Field{ + zap.String("component", "service.openai_gateway"), + zap.Int64("account_id", accountID), + zap.Bool("codex_cli_only_enabled", result.Enabled), + zap.Bool("codex_official_client_match", result.Matched), + zap.String("reject_reason", result.Reason), + } + if apiKeyID > 0 { + fields = append(fields, zap.Int64("api_key_id", apiKeyID)) + } + if !result.Matched { + fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, body) + } + log := logger.FromContext(ctx).With(fields...) 
+ if result.Matched { + return + } + log.Warn("OpenAI codex_cli_only 拒绝非官方客户端请求") +} + +func appendCodexCLIOnlyRejectedRequestFields(fields []zap.Field, c *gin.Context, body []byte) []zap.Field { + if c == nil || c.Request == nil { + return fields + } + + req := c.Request + requestModel, requestStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) + fields = append(fields, + zap.String("request_method", strings.TrimSpace(req.Method)), + zap.String("request_path", strings.TrimSpace(req.URL.Path)), + zap.String("request_query", strings.TrimSpace(req.URL.RawQuery)), + zap.String("request_host", strings.TrimSpace(req.Host)), + zap.String("request_client_ip", strings.TrimSpace(c.ClientIP())), + zap.String("request_remote_addr", strings.TrimSpace(req.RemoteAddr)), + zap.String("request_user_agent", strings.TrimSpace(req.Header.Get("User-Agent"))), + zap.String("request_content_type", strings.TrimSpace(req.Header.Get("Content-Type"))), + zap.Int64("request_content_length", req.ContentLength), + zap.Bool("request_stream", requestStream), + ) + if requestModel != "" { + fields = append(fields, zap.String("request_model", requestModel)) + } + if promptCacheKey != "" { + fields = append(fields, zap.String("request_prompt_cache_key_sha256", hashSensitiveValueForLog(promptCacheKey))) + } + + if headers := snapshotCodexCLIOnlyHeaders(req.Header); len(headers) > 0 { + fields = append(fields, zap.Any("request_headers", headers)) + } + fields = append(fields, zap.Int("request_body_size", len(body))) + return fields +} + +func snapshotCodexCLIOnlyHeaders(header http.Header) map[string]string { + if len(header) == 0 { + return nil + } + result := make(map[string]string, len(codexCLIOnlyDebugHeaderWhitelist)) + for _, key := range codexCLIOnlyDebugHeaderWhitelist { + value := strings.TrimSpace(header.Get(key)) + if value == "" { + continue + } + result[strings.ToLower(key)] = truncateString(value, codexCLIOnlyHeaderValueMaxBytes) + } + return result +} + +func 
hashSensitiveValueForLog(raw string) string { + value := strings.TrimSpace(raw) + if value == "" { + return "" + } + sum := sha256.Sum256([]byte(value)) + return hex.EncodeToString(sum[:8]) +} + +func logOpenAIInstructionsRequiredDebug( + ctx context.Context, + c *gin.Context, + account *Account, + upstreamStatusCode int, + upstreamMsg string, + requestBody []byte, + upstreamBody []byte, +) { + msg := strings.TrimSpace(upstreamMsg) + if !isOpenAIInstructionsRequiredError(upstreamStatusCode, msg, upstreamBody) { + return + } + if ctx == nil { + ctx = context.Background() + } + + accountID := int64(0) + accountName := "" + if account != nil { + accountID = account.ID + accountName = strings.TrimSpace(account.Name) + } + + userAgent := "" + if c != nil { + userAgent = strings.TrimSpace(c.GetHeader("User-Agent")) + } + + fields := []zap.Field{ + zap.String("component", "service.openai_gateway"), + zap.Int64("account_id", accountID), + zap.String("account_name", accountName), + zap.Int("upstream_status_code", upstreamStatusCode), + zap.String("upstream_error_message", msg), + zap.String("request_user_agent", userAgent), + zap.Bool("codex_official_client_match", openai.IsCodexCLIRequest(userAgent)), + } + fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, requestBody) + + logger.FromContext(ctx).With(fields...).Warn("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查") +} + +func isOpenAIInstructionsRequiredError(upstreamStatusCode int, upstreamMsg string, upstreamBody []byte) bool { + if upstreamStatusCode != http.StatusBadRequest { + return false + } + + hasInstructionRequired := func(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false + } + if strings.Contains(lower, "instructions are required") { + return true + } + if strings.Contains(lower, "required parameter: 'instructions'") { + return true + } + if strings.Contains(lower, "required parameter: instructions") { + return true + } + if 
strings.Contains(lower, "missing required parameter") && strings.Contains(lower, "instructions") { + return true + } + return strings.Contains(lower, "instruction") && strings.Contains(lower, "required") + } + + if hasInstructionRequired(upstreamMsg) { + return true + } + if len(upstreamBody) == 0 { + return false + } + + errMsg := gjson.GetBytes(upstreamBody, "error.message").String() + errMsgLower := strings.ToLower(strings.TrimSpace(errMsg)) + errCode := strings.ToLower(strings.TrimSpace(gjson.GetBytes(upstreamBody, "error.code").String())) + errParam := strings.ToLower(strings.TrimSpace(gjson.GetBytes(upstreamBody, "error.param").String())) + errType := strings.ToLower(strings.TrimSpace(gjson.GetBytes(upstreamBody, "error.type").String())) + + if errParam == "instructions" { + return true + } + if hasInstructionRequired(errMsg) { + return true + } + if strings.Contains(errCode, "missing_required_parameter") && strings.Contains(errMsgLower, "instructions") { + return true + } + if strings.Contains(errType, "invalid_request") && strings.Contains(errMsgLower, "instructions") && strings.Contains(errMsgLower, "required") { + return true + } + + return false } // GenerateSessionHash generates a sticky-session hash for OpenAI requests. @@ -228,7 +882,7 @@ func NewOpenAIGatewayService( // 1. Header: session_id // 2. Header: conversation_id // 3. 
Body: prompt_cache_key (opencode) -func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, reqBody map[string]any) string { +func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte) string { if c == nil { return "" } @@ -237,17 +891,35 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, reqBody map[s if sessionID == "" { sessionID = strings.TrimSpace(c.GetHeader("conversation_id")) } - if sessionID == "" && reqBody != nil { - if v, ok := reqBody["prompt_cache_key"].(string); ok { - sessionID = strings.TrimSpace(v) - } + if sessionID == "" && len(body) > 0 { + sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) } if sessionID == "" { return "" } - hash := sha256.Sum256([]byte(sessionID)) - return hex.EncodeToString(hash[:]) + currentHash, legacyHash := deriveOpenAISessionHashes(sessionID) + attachOpenAILegacySessionHashToGin(c, legacyHash) + return currentHash +} + +// GenerateSessionHashWithFallback 先按常规信号生成会话哈希; +// 当未携带 session_id/conversation_id/prompt_cache_key 时,使用 fallbackSeed 生成稳定哈希。 +// 该方法用于 WS ingress,避免会话信号缺失时发生跨账号漂移。 +func (s *OpenAIGatewayService) GenerateSessionHashWithFallback(c *gin.Context, body []byte, fallbackSeed string) string { + sessionHash := s.GenerateSessionHash(c, body) + if sessionHash != "" { + return sessionHash + } + + seed := strings.TrimSpace(fallbackSeed) + if seed == "" { + return "" + } + + currentHash, legacyHash := deriveOpenAISessionHashes(seed) + attachOpenAILegacySessionHashToGin(c, legacyHash) + return currentHash } // BindStickySession sets session -> account binding with standard TTL. 
@@ -255,7 +927,11 @@ func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *i if sessionHash == "" || accountID <= 0 { return nil } - return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, accountID, openaiStickySessionTTL) + ttl := openaiStickySessionTTL + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds > 0 { + ttl = time.Duration(s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) * time.Second + } + return s.setStickySessionAccountID(ctx, groupID, sessionHash, accountID, ttl) } // SelectAccount selects an OpenAI account with sticky session support @@ -271,11 +947,13 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI // SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. // SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。 func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { - cacheKey := "openai:" + sessionHash + return s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, 0) +} +func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) (*Account, error) { // 1. 尝试粘性会话命中 // Try sticky session hit - if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs); account != nil { + if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID); account != nil { return account, nil } @@ -300,7 +978,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C // 4. 
设置粘性会话绑定 // Set sticky session binding if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, openaiStickySessionTTL) + _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, selected.ID, openaiStickySessionTTL) } return selected, nil @@ -311,14 +989,18 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C // // tryStickySessionHit attempts to get account from sticky session. // Returns account if hit and usable; clears session and returns nil if account is unavailable. -func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, cacheKey, requestedModel string, excludedIDs map[int64]struct{}) *Account { +func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) *Account { if sessionHash == "" { return nil } - accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey) - if err != nil || accountID <= 0 { - return nil + accountID := stickyAccountID + if accountID <= 0 { + var err error + accountID, err = s.getStickySessionAccountID(ctx, groupID, sessionHash) + if err != nil || accountID <= 0 { + return nil + } } if _, excluded := excludedIDs[accountID]; excluded { @@ -333,7 +1015,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // 检查账号是否需要清理粘性会话 // Check if sticky session should be cleared if shouldClearStickySession(account, requestedModel) { - _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey) + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) return nil } @@ -348,7 +1030,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // 刷新会话 TTL 并返回账号 // Refresh session TTL and return account - _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, openaiStickySessionTTL) + _ = 
s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) return account } @@ -434,12 +1116,12 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex cfg := s.schedulingConfig() var stickyAccountID int64 if sessionHash != "" && s.cache != nil { - if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash); err == nil { + if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil { stickyAccountID = accountID } } if s.concurrencyService == nil || !cfg.LoadBatchEnabled { - account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) + account, err := s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID) if err != nil { return nil, err } @@ -494,19 +1176,19 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex // ============ Layer 1: Sticky session ============ if sessionHash != "" { - accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash) - if err == nil && accountID > 0 && !isExcluded(accountID) { + accountID := stickyAccountID + if accountID > 0 && !isExcluded(accountID) { account, err := s.getSchedulableAccount(ctx, accountID) if err == nil { clearSticky := shouldClearStickySession(account, requestedModel) if clearSticky { - _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash) + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) } if !clearSticky && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { - _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL) + _ = s.refreshStickySessionTTL(ctx, groupID, 
sessionHash, openaiStickySessionTTL) return &AccountSelectionResult{ Account: account, Acquired: true, @@ -570,7 +1252,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, acc.ID, openaiStickySessionTTL) + _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, acc.ID, openaiStickySessionTTL) } return &AccountSelectionResult{ Account: acc, @@ -620,7 +1302,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL) + _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, item.account.ID, openaiStickySessionTTL) } return &AccountSelectionResult{ Account: item.account, @@ -661,7 +1343,7 @@ func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, grou } else if groupID != nil { accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI) } else { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) + accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, PlatformOpenAI) } if err != nil { return nil, fmt.Errorf("query accounts failed: %w", err) @@ -744,43 +1426,159 @@ func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, re func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) { startTime := time.Now() - // Parse request body once (avoid multiple parse/serialize cycles) - var reqBody map[string]any - if err := json.Unmarshal(body, 
&reqBody); err != nil { - return nil, fmt.Errorf("parse request: %w", err) + restrictionResult := s.detectCodexClientRestriction(c, account) + apiKeyID := getAPIKeyIDFromContext(c) + logCodexCLIOnlyDetection(ctx, c, account, apiKeyID, restrictionResult, body) + if restrictionResult.Enabled && !restrictionResult.Matched { + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "type": "forbidden_error", + "message": "This account only allows Codex official clients", + }, + }) + return nil, errors.New("codex_cli_only restriction: only codex official clients are allowed") } - // Extract model and stream from parsed body - reqModel, _ := reqBody["model"].(string) - reqStream, _ := reqBody["stream"].(bool) - promptCacheKey := "" - if v, ok := reqBody["prompt_cache_key"].(string); ok { - promptCacheKey = strings.TrimSpace(v) + originalBody := body + reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) + originalModel := reqModel + + isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) + wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) + clientTransport := GetOpenAIClientTransport(c) + // 仅允许 WS 入站请求走 WS 上游,避免出现 HTTP -> WS 协议混用。 + wsDecision = resolveOpenAIWSDecisionByClientTransport(wsDecision, clientTransport) + if c != nil { + c.Set("openai_ws_transport_decision", string(wsDecision.Transport)) + c.Set("openai_ws_transport_reason", wsDecision.Reason) + } + if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocketV2 { + logOpenAIWSModeDebug( + "selected account_id=%d account_type=%s transport=%s reason=%s model=%s stream=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(wsDecision.Transport)), + normalizeOpenAIWSLogValue(wsDecision.Reason), + reqModel, + reqStream, + ) + } + // 当前仅支持 WSv2;WSv1 命中时直接返回错误,避免出现“配置可开但行为不确定”。 + if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocket { + if c != nil { + c.JSON(http.StatusBadRequest, 
gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": "OpenAI WSv1 is temporarily unsupported. Please enable responses_websockets_v2.", + }, + }) + } + return nil, errors.New("openai ws v1 is temporarily unsupported; use ws v2") + } + passthroughEnabled := account.IsOpenAIPassthroughEnabled() + if passthroughEnabled { + // 透传分支只需要轻量提取字段,避免热路径全量 Unmarshal。 + reasoningEffort := extractOpenAIReasoningEffortFromBody(body, reqModel) + return s.forwardOpenAIPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime) + } + + reqBody, err := getOpenAIRequestBodyMap(c, body) + if err != nil { + return nil, err + } + + if v, ok := reqBody["model"].(string); ok { + reqModel = v + originalModel = reqModel + } + if v, ok := reqBody["stream"].(bool); ok { + reqStream = v + } + if promptCacheKey == "" { + if v, ok := reqBody["prompt_cache_key"].(string); ok { + promptCacheKey = strings.TrimSpace(v) + } } // Track if body needs re-serialization bodyModified := false - originalModel := reqModel + // 单字段补丁快速路径:只要整个变更集最终可归约为同一路径的 set/delete,就避免全量 Marshal。 + patchDisabled := false + patchHasOp := false + patchDelete := false + patchPath := "" + var patchValue any + markPatchSet := func(path string, value any) { + if strings.TrimSpace(path) == "" { + patchDisabled = true + return + } + if patchDisabled { + return + } + if !patchHasOp { + patchHasOp = true + patchDelete = false + patchPath = path + patchValue = value + return + } + if patchDelete || patchPath != path { + patchDisabled = true + return + } + patchValue = value + } + markPatchDelete := func(path string) { + if strings.TrimSpace(path) == "" { + patchDisabled = true + return + } + if patchDisabled { + return + } + if !patchHasOp { + patchHasOp = true + patchDelete = true + patchPath = path + return + } + if !patchDelete || patchPath != path { + patchDisabled = true + } + } + disablePatch := func() { + patchDisabled = true + } - isCodexCLI := 
openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) + // 非透传模式下,保持历史行为:非 Codex CLI 请求在 instructions 为空时注入默认指令。 + if !isCodexCLI && isInstructionsEmpty(reqBody) { + if instructions := strings.TrimSpace(GetOpenCodeInstructions()); instructions != "" { + reqBody["instructions"] = instructions + bodyModified = true + markPatchSet("instructions", instructions) + } + } // 对所有请求执行模型映射(包含 Codex CLI)。 mappedModel := account.GetMappedModel(reqModel) if mappedModel != reqModel { - log.Printf("[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI) + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI) reqBody["model"] = mappedModel bodyModified = true + markPatchSet("model", mappedModel) } // 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。 if model, ok := reqBody["model"].(string); ok { normalizedModel := normalizeCodexModel(model) if normalizedModel != "" && normalizedModel != model { - log.Printf("[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", model, normalizedModel, account.Name, account.Type, isCodexCLI) reqBody["model"] = normalizedModel mappedModel = normalizedModel bodyModified = true + markPatchSet("model", normalizedModel) } } @@ -789,7 +1587,8 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco if effort, ok := reasoning["effort"].(string); ok && effort == "minimal" { reasoning["effort"] = "none" bodyModified = true - log.Printf("[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name) + markPatchSet("reasoning.effort", "none") + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name) } } @@ 
-797,6 +1596,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI) if codexResult.Modified { bodyModified = true + disablePatch() } if codexResult.NormalizedModel != "" { mappedModel = codexResult.NormalizedModel @@ -816,22 +1616,27 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco if account.Type == AccountTypeAPIKey { delete(reqBody, "max_output_tokens") bodyModified = true + markPatchDelete("max_output_tokens") } case PlatformAnthropic: // For Anthropic (Claude), convert to max_tokens delete(reqBody, "max_output_tokens") + markPatchDelete("max_output_tokens") if _, hasMaxTokens := reqBody["max_tokens"]; !hasMaxTokens { reqBody["max_tokens"] = maxOutputTokens + disablePatch() } bodyModified = true case PlatformGemini: // For Gemini, remove (will be handled by Gemini-specific transform) delete(reqBody, "max_output_tokens") bodyModified = true + markPatchDelete("max_output_tokens") default: // For unknown platforms, remove to be safe delete(reqBody, "max_output_tokens") bodyModified = true + markPatchDelete("max_output_tokens") } } @@ -840,24 +1645,51 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco if account.Type == AccountTypeAPIKey || account.Platform != PlatformOpenAI { delete(reqBody, "max_completion_tokens") bodyModified = true + markPatchDelete("max_completion_tokens") } } // Remove unsupported fields (not supported by upstream OpenAI API) - for _, unsupportedField := range []string{"prompt_cache_retention", "safety_identifier", "previous_response_id"} { + unsupportedFields := []string{"prompt_cache_retention", "safety_identifier"} + for _, unsupportedField := range unsupportedFields { if _, has := reqBody[unsupportedField]; has { delete(reqBody, unsupportedField) bodyModified = true + markPatchDelete(unsupportedField) } } } + // 仅在 WSv2 模式保留 previous_response_id,其他模式(HTTP/WSv1)统一过滤。 + // 
注意:该规则同样适用于 Codex CLI 请求,避免 WSv1 向上游透传不支持字段。 + if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + if _, has := reqBody["previous_response_id"]; has { + delete(reqBody, "previous_response_id") + bodyModified = true + markPatchDelete("previous_response_id") + } + } + // Re-serialize body only if modified if bodyModified { - var err error - body, err = json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("serialize request body: %w", err) + serializedByPatch := false + if !patchDisabled && patchHasOp { + var patchErr error + if patchDelete { + body, patchErr = sjson.DeleteBytes(body, patchPath) + } else { + body, patchErr = sjson.SetBytes(body, patchPath, patchValue) + } + if patchErr == nil { + serializedByPatch = true + } + } + if !serializedByPatch { + var marshalErr error + body, marshalErr = json.Marshal(reqBody) + if marshalErr != nil { + return nil, fmt.Errorf("serialize request body: %w", marshalErr) + } } } @@ -867,6 +1699,184 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco return nil, err } + // Capture upstream request body for ops retry of this attempt. 
+ setOpsUpstreamRequestBody(c, body) + + // 命中 WS 时仅走 WebSocket Mode;不再自动回退 HTTP。 + if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocketV2 { + wsReqBody := reqBody + if len(reqBody) > 0 { + wsReqBody = make(map[string]any, len(reqBody)) + for k, v := range reqBody { + wsReqBody[k] = v + } + } + _, hasPreviousResponseID := wsReqBody["previous_response_id"] + logOpenAIWSModeDebug( + "forward_start account_id=%d account_type=%s model=%s stream=%v has_previous_response_id=%v", + account.ID, + account.Type, + mappedModel, + reqStream, + hasPreviousResponseID, + ) + maxAttempts := openAIWSReconnectRetryLimit + 1 + wsAttempts := 0 + var wsResult *OpenAIForwardResult + var wsErr error + wsLastFailureReason := "" + wsPrevResponseRecoveryTried := false + recoverPrevResponseNotFound := func(attempt int) bool { + if wsPrevResponseRecoveryTried { + return false + } + previousResponseID := openAIWSPayloadString(wsReqBody, "previous_response_id") + if previousResponseID == "" { + logOpenAIWSModeInfo( + "reconnect_prev_response_recovery_skip account_id=%d attempt=%d reason=missing_previous_response_id previous_response_id_present=false", + account.ID, + attempt, + ) + return false + } + if HasFunctionCallOutput(wsReqBody) { + logOpenAIWSModeInfo( + "reconnect_prev_response_recovery_skip account_id=%d attempt=%d reason=has_function_call_output previous_response_id_present=true", + account.ID, + attempt, + ) + return false + } + delete(wsReqBody, "previous_response_id") + wsPrevResponseRecoveryTried = true + logOpenAIWSModeInfo( + "reconnect_prev_response_recovery account_id=%d attempt=%d action=drop_previous_response_id retry=1 previous_response_id=%s previous_response_id_kind=%s", + account.ID, + attempt, + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(ClassifyOpenAIPreviousResponseIDKind(previousResponseID)), + ) + return true + } + retryBudget := s.openAIWSRetryTotalBudget() + retryStartedAt := time.Now() + 
wsRetryLoop: + for attempt := 1; attempt <= maxAttempts; attempt++ { + wsAttempts = attempt + wsResult, wsErr = s.forwardOpenAIWSV2( + ctx, + c, + account, + wsReqBody, + token, + wsDecision, + isCodexCLI, + reqStream, + originalModel, + mappedModel, + startTime, + attempt, + wsLastFailureReason, + ) + if wsErr == nil { + break + } + if c != nil && c.Writer != nil && c.Writer.Written() { + break + } + + reason, retryable := classifyOpenAIWSReconnectReason(wsErr) + if reason != "" { + wsLastFailureReason = reason + } + // previous_response_not_found 说明续链锚点不可用: + // 对非 function_call_output 场景,允许一次“去掉 previous_response_id 后重放”。 + if reason == "previous_response_not_found" && recoverPrevResponseNotFound(attempt) { + continue + } + if retryable && attempt < maxAttempts { + backoff := s.openAIWSRetryBackoff(attempt) + if retryBudget > 0 && time.Since(retryStartedAt)+backoff > retryBudget { + s.recordOpenAIWSRetryExhausted() + logOpenAIWSModeInfo( + "reconnect_budget_exhausted account_id=%d attempts=%d max_retries=%d reason=%s elapsed_ms=%d budget_ms=%d", + account.ID, + attempt, + openAIWSReconnectRetryLimit, + normalizeOpenAIWSLogValue(reason), + time.Since(retryStartedAt).Milliseconds(), + retryBudget.Milliseconds(), + ) + break + } + s.recordOpenAIWSRetryAttempt(backoff) + logOpenAIWSModeInfo( + "reconnect_retry account_id=%d retry=%d max_retries=%d reason=%s backoff_ms=%d", + account.ID, + attempt, + openAIWSReconnectRetryLimit, + normalizeOpenAIWSLogValue(reason), + backoff.Milliseconds(), + ) + if backoff > 0 { + timer := time.NewTimer(backoff) + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + wsErr = wrapOpenAIWSFallback("retry_backoff_canceled", ctx.Err()) + break wsRetryLoop + case <-timer.C: + } + } + continue + } + if retryable { + s.recordOpenAIWSRetryExhausted() + logOpenAIWSModeInfo( + "reconnect_exhausted account_id=%d attempts=%d max_retries=%d reason=%s", + account.ID, + attempt, + openAIWSReconnectRetryLimit, + 
normalizeOpenAIWSLogValue(reason), + ) + } else if reason != "" { + s.recordOpenAIWSNonRetryableFastFallback() + logOpenAIWSModeInfo( + "reconnect_stop account_id=%d attempt=%d reason=%s", + account.ID, + attempt, + normalizeOpenAIWSLogValue(reason), + ) + } + break + } + if wsErr == nil { + firstTokenMs := int64(0) + hasFirstTokenMs := wsResult != nil && wsResult.FirstTokenMs != nil + if hasFirstTokenMs { + firstTokenMs = int64(*wsResult.FirstTokenMs) + } + requestID := "" + if wsResult != nil { + requestID = strings.TrimSpace(wsResult.RequestID) + } + logOpenAIWSModeDebug( + "forward_succeeded account_id=%d request_id=%s stream=%v has_first_token_ms=%v first_token_ms=%d ws_attempts=%d", + account.ID, + requestID, + reqStream, + hasFirstTokenMs, + firstTokenMs, + wsAttempts, + ) + return wsResult, nil + } + s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr) + return nil, wsErr + } + // Build upstream request upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) if err != nil { @@ -879,13 +1889,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco proxyURL = account.Proxy.URL() } - // Capture upstream request body for ops retry of this attempt. - if c != nil { - c.Set(OpsUpstreamRequestBodyKey, string(body)) - } - // Send request + upstreamStart := time.Now() resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds()) if err != nil { // Ensure the client receives an error response (handlers assume Forward writes on non-failover errors). 
safeErr := sanitizeUpstreamErrorMessage(err.Error()) @@ -939,7 +1946,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco s.handleFailoverSideEffects(ctx, resp, account) return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } - return s.handleErrorResponse(ctx, resp, c, account) + return s.handleErrorResponse(ctx, resp, c, account, body) } // Handle normal response @@ -966,6 +1973,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } } + if usage == nil { + usage = &OpenAIUsage{} + } + reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel) return &OpenAIForwardResult{ @@ -974,11 +1985,576 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco Model: originalModel, ReasoningEffort: reasoningEffort, Stream: reqStream, + OpenAIWSMode: false, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, }, nil } +func (s *OpenAIGatewayService) forwardOpenAIPassthrough( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + reqModel string, + reasoningEffort *string, + reqStream bool, + startTime time.Time, +) (*OpenAIForwardResult, error) { + if account != nil && account.Type == AccountTypeOAuth { + if rejectReason := detectOpenAIPassthroughInstructionsRejectReason(reqModel, body); rejectReason != "" { + rejectMsg := "OpenAI codex passthrough requires a non-empty instructions field" + setOpsUpstreamError(c, http.StatusForbidden, rejectMsg, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: http.StatusForbidden, + Passthrough: true, + Kind: "request_error", + Message: rejectMsg, + Detail: rejectReason, + }) + logOpenAIPassthroughInstructionsRejected(ctx, c, account, reqModel, rejectReason, body) + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "type": "forbidden_error", + "message": 
rejectMsg, + }, + }) + return nil, fmt.Errorf("openai passthrough rejected before upstream: %s", rejectReason) + } + + normalizedBody, normalized, err := normalizeOpenAIPassthroughOAuthBody(body) + if err != nil { + return nil, err + } + if normalized { + body = normalizedBody + reqStream = true + } + } + + logger.LegacyPrintf("service.openai_gateway", + "[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v", + account.ID, + account.Name, + account.Type, + reqModel, + reqStream, + ) + if reqStream && c != nil && c.Request != nil { + if timeoutHeaders := collectOpenAIPassthroughTimeoutHeaders(c.Request.Header); len(timeoutHeaders) > 0 { + streamWarnLogger := logger.FromContext(ctx).With( + zap.String("component", "service.openai_gateway"), + zap.Int64("account_id", account.ID), + zap.Strings("timeout_headers", timeoutHeaders), + ) + if s.isOpenAIPassthroughTimeoutHeadersAllowed() { + streamWarnLogger.Warn("OpenAI passthrough 透传请求包含超时相关请求头,且当前配置为放行,可能导致上游提前断流") + } else { + streamWarnLogger.Warn("OpenAI passthrough 检测到超时相关请求头,将按配置过滤以降低断流风险") + } + } + } + + // Get access token + token, _, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, err + } + + upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(ctx, c, account, body, token) + if err != nil { + return nil, err + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + setOpsUpstreamRequestBody(c, body) + if c != nil { + c.Set("openai_passthrough", true) + } + + upstreamStart := time.Now() + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds()) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + 
UpstreamStatusCode: 0, + Passthrough: true, + Kind: "request_error", + Message: safeErr, + }) + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed", + }, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + // 透传模式不做 failover(避免改变原始上游语义),按上游原样返回错误响应。 + return nil, s.handleErrorResponsePassthrough(ctx, resp, c, account, body) + } + + var usage *OpenAIUsage + var firstTokenMs *int + if reqStream { + result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime) + if err != nil { + return nil, err + } + usage = result.usage + firstTokenMs = result.firstTokenMs + } else { + usage, err = s.handleNonStreamingResponsePassthrough(ctx, resp, c) + if err != nil { + return nil, err + } + } + + if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { + s.updateCodexUsageSnapshot(ctx, account.ID, snapshot) + } + + if usage == nil { + usage = &OpenAIUsage{} + } + + return &OpenAIForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: *usage, + Model: reqModel, + ReasoningEffort: reasoningEffort, + Stream: reqStream, + OpenAIWSMode: false, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, nil +} + +func logOpenAIPassthroughInstructionsRejected( + ctx context.Context, + c *gin.Context, + account *Account, + reqModel string, + rejectReason string, + body []byte, +) { + if ctx == nil { + ctx = context.Background() + } + accountID := int64(0) + accountName := "" + accountType := "" + if account != nil { + accountID = account.ID + accountName = strings.TrimSpace(account.Name) + accountType = strings.TrimSpace(string(account.Type)) + } + fields := []zap.Field{ + zap.String("component", "service.openai_gateway"), + zap.Int64("account_id", accountID), + zap.String("account_name", accountName), + zap.String("account_type", accountType), + 
zap.String("request_model", strings.TrimSpace(reqModel)), + zap.String("reject_reason", strings.TrimSpace(rejectReason)), + } + fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, body) + logger.FromContext(ctx).With(fields...).Warn("OpenAI passthrough 本地拦截:Codex 请求缺少有效 instructions") +} + +func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + token string, +) (*http.Request, error) { + targetURL := openaiPlatformAPIURL + switch account.Type { + case AccountTypeOAuth: + targetURL = chatgptCodexURL + case AccountTypeAPIKey: + baseURL := account.GetOpenAIBaseURL() + if baseURL != "" { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + targetURL = buildOpenAIResponsesURL(validatedURL) + } + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + // 透传客户端请求头(安全白名单)。 + allowTimeoutHeaders := s.isOpenAIPassthroughTimeoutHeadersAllowed() + if c != nil && c.Request != nil { + for key, values := range c.Request.Header { + lower := strings.ToLower(strings.TrimSpace(key)) + if !isOpenAIPassthroughAllowedRequestHeader(lower, allowTimeoutHeaders) { + continue + } + for _, v := range values { + req.Header.Add(key, v) + } + } + } + + // 覆盖入站鉴权残留,并注入上游认证 + req.Header.Del("authorization") + req.Header.Del("x-api-key") + req.Header.Del("x-goog-api-key") + req.Header.Set("authorization", "Bearer "+token) + + // OAuth 透传到 ChatGPT internal API 时补齐必要头。 + if account.Type == AccountTypeOAuth { + promptCacheKey := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) + req.Host = "chatgpt.com" + if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { + req.Header.Set("chatgpt-account-id", chatgptAccountID) + } + if req.Header.Get("accept") == "" { + req.Header.Set("accept", "text/event-stream") + } + if 
req.Header.Get("OpenAI-Beta") == "" { + req.Header.Set("OpenAI-Beta", "responses=experimental") + } + if req.Header.Get("originator") == "" { + req.Header.Set("originator", "codex_cli_rs") + } + if promptCacheKey != "" { + if req.Header.Get("conversation_id") == "" { + req.Header.Set("conversation_id", promptCacheKey) + } + if req.Header.Get("session_id") == "" { + req.Header.Set("session_id", promptCacheKey) + } + } + } + + // 透传模式也支持账户自定义 User-Agent 与 ForceCodexCLI 兜底。 + customUA := account.GetOpenAIUserAgent() + if customUA != "" { + req.Header.Set("user-agent", customUA) + } + if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { + req.Header.Set("user-agent", codexCLIUserAgent) + } + // OAuth 安全透传:对非 Codex UA 统一兜底,降低被上游风控拦截概率。 + if account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(req.Header.Get("user-agent")) { + req.Header.Set("user-agent", codexCLIUserAgent) + } + + if req.Header.Get("content-type") == "" { + req.Header.Set("content-type", "application/json") + } + + return req, nil +} + +func (s *OpenAIGatewayService) handleErrorResponsePassthrough( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, + requestBody []byte, +) error { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(body), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: 
resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + UpstreamResponseBody: upstreamDetail, + }) + + writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "application/json" + } + c.Data(resp.StatusCode, contentType, body) + + if upstreamMsg == "" { + return fmt.Errorf("upstream error: %d", resp.StatusCode) + } + return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) +} + +func isOpenAIPassthroughAllowedRequestHeader(lowerKey string, allowTimeoutHeaders bool) bool { + if lowerKey == "" { + return false + } + if isOpenAIPassthroughTimeoutHeader(lowerKey) { + return allowTimeoutHeaders + } + return openaiPassthroughAllowedHeaders[lowerKey] +} + +func isOpenAIPassthroughTimeoutHeader(lowerKey string) bool { + switch lowerKey { + case "x-stainless-timeout", "x-stainless-read-timeout", "x-stainless-connect-timeout", "x-request-timeout", "request-timeout", "grpc-timeout": + return true + default: + return false + } +} + +func (s *OpenAIGatewayService) isOpenAIPassthroughTimeoutHeadersAllowed() bool { + return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIPassthroughAllowTimeoutHeaders +} + +func collectOpenAIPassthroughTimeoutHeaders(h http.Header) []string { + if h == nil { + return nil + } + var matched []string + for key, values := range h { + lowerKey := strings.ToLower(strings.TrimSpace(key)) + if isOpenAIPassthroughTimeoutHeader(lowerKey) { + entry := lowerKey + if len(values) > 0 { + entry = fmt.Sprintf("%s=%s", lowerKey, strings.Join(values, "|")) + } + matched = append(matched, entry) + } + } + sort.Strings(matched) + return matched +} + +type openaiStreamingResultPassthrough struct { + usage *OpenAIUsage + firstTokenMs *int +} + +func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( + ctx context.Context, + resp 
*http.Response, + c *gin.Context, + account *Account, + startTime time.Time, +) (*openaiStreamingResultPassthrough, error) { + writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + + // SSE headers + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + if v := resp.Header.Get("x-request-id"); v != "" { + c.Header("x-request-id", v) + } + + w := c.Writer + flusher, ok := w.(http.Flusher) + if !ok { + return nil, errors.New("streaming not supported") + } + + usage := &OpenAIUsage{} + var firstTokenMs *int + clientDisconnected := false + sawDone := false + upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id")) + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) + defer putSSEScannerBuf64K(scanBuf) + + for scanner.Scan() { + line := scanner.Text() + if data, ok := extractOpenAISSEDataLine(line); ok { + dataBytes := []byte(data) + trimmedData := strings.TrimSpace(data) + if trimmedData == "[DONE]" { + sawDone = true + } + if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseSSEUsageBytes(dataBytes, usage) + } + + if !clientDisconnected { + if _, err := fmt.Fprintln(w, line); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) + } else { + flusher.Flush() + } + } + } + if err := scanner.Err(); err != nil { + if clientDisconnected { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Upstream read error after client disconnect: 
account=%d err=%v", account.ID, err) + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + logger.LegacyPrintf("service.openai_gateway", + "[OpenAI passthrough] 流读取被取消,可能发生断流: account=%d request_id=%s err=%v ctx_err=%v", + account.ID, + upstreamRequestID, + err, + ctx.Err(), + ) + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil + } + if errors.Is(err, bufio.ErrTooLong) { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err) + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err + } + logger.LegacyPrintf("service.openai_gateway", + "[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v", + account.ID, + upstreamRequestID, + err, + ) + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err) + } + if !clientDisconnected && !sawDone && ctx.Err() == nil { + logger.FromContext(ctx).With( + zap.String("component", "service.openai_gateway"), + zap.Int64("account_id", account.ID), + zap.String("upstream_request_id", upstreamRequestID), + ).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流") + } + + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil +} + +func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( + ctx context.Context, + resp *http.Response, + c *gin.Context, +) (*OpenAIUsage, error) { + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", 
+ "message": "Upstream response too large", + }, + }) + } + return nil, err + } + + usage := &OpenAIUsage{} + usageParsed := false + if len(body) > 0 { + if parsedUsage, ok := extractOpenAIUsageFromJSONBytes(body); ok { + *usage = parsedUsage + usageParsed = true + } + } + if !usageParsed { + // 兜底:尝试从 SSE 文本中解析 usage + usage = s.parseSSEUsageFromBody(string(body)) + } + + writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "application/json" + } + c.Data(resp.StatusCode, contentType, body) + return usage, nil +} + +func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) { + if dst == nil || src == nil { + return + } + if filter != nil { + responseheaders.WriteFilteredHeaders(dst, src, filter) + } else { + // 兜底:尽量保留最基础的 content-type + if v := strings.TrimSpace(src.Get("Content-Type")); v != "" { + dst.Set("Content-Type", v) + } + } + // 透传模式强制放行 x-codex-* 响应头(若上游返回)。 + // 注意:真实 http.Response.Header 的 key 一般会被 canonicalize;但为了兼容测试/自建响应, + // 这里用 EqualFold 做一次大小写不敏感的查找。 + getCaseInsensitiveValues := func(h http.Header, want string) []string { + if h == nil { + return nil + } + for k, vals := range h { + if strings.EqualFold(k, want) { + return vals + } + } + return nil + } + + for _, rawKey := range []string{ + "x-codex-primary-used-percent", + "x-codex-primary-reset-after-seconds", + "x-codex-primary-window-minutes", + "x-codex-secondary-used-percent", + "x-codex-secondary-reset-after-seconds", + "x-codex-secondary-window-minutes", + "x-codex-primary-over-secondary-limit-percent", + } { + vals := getCaseInsensitiveValues(src, rawKey) + if len(vals) == 0 { + continue + } + key := http.CanonicalHeaderKey(rawKey) + dst.Del(key) + for _, v := range vals { + dst.Add(key, v) + } + } +} + func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c 
*gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string, isCodexCLI bool) (*http.Request, error) { // Determine target URL based on account type var targetURL string @@ -996,7 +2572,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. if err != nil { return nil, err } - targetURL = validatedURL + "/responses" + targetURL = buildOpenAIResponsesURL(validatedURL) } default: targetURL = openaiPlatformAPIURL @@ -1050,6 +2626,12 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. req.Header.Set("user-agent", customUA) } + // 若开启 ForceCodexCLI,则强制将上游 User-Agent 伪装为 Codex CLI。 + // 用于网关未透传/改写 User-Agent 时,仍能命中 Codex 侧识别逻辑。 + if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { + req.Header.Set("user-agent", codexCLIUserAgent) + } + // Ensure required headers exist if req.Header.Get("content-type") == "" { req.Header.Set("content-type", "application/json") @@ -1058,7 +2640,13 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. 
return req, nil } -func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*OpenAIForwardResult, error) { +func (s *OpenAIGatewayService) handleErrorResponse( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, + requestBody []byte, +) (*OpenAIForwardResult, error) { body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) @@ -1072,9 +2660,10 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht upstreamDetail = truncateString(string(body), maxBytes) } setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf( + logger.LegacyPrintf("service.openai_gateway", "OpenAI upstream error %d (account=%d platform=%s type=%s): %s", resp.StatusCode, account.ID, @@ -1202,8 +2791,8 @@ type openaiStreamingResult struct { } func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) { - if s.cfg != nil { - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) } // Set SSE response headers @@ -1222,6 +2811,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if !ok { return nil, errors.New("streaming not supported") } + bufferedWriter := bufio.NewWriterSize(w, 4*1024) + flushBuffered := func() error { + if err := bufferedWriter.Flush(); err != nil { + return err + } + flusher.Flush() + return nil + } usage := &OpenAIUsage{} var firstTokenMs 
*int @@ -1230,38 +2827,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) - - type scanEvent struct { - line string - err error - } - // 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理 - events := make(chan scanEvent, 16) - done := make(chan struct{}) - sendEvent := func(ev scanEvent) bool { - select { - case events <- ev: - return true - case <-done: - return false - } - } - var lastReadAt int64 - atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { - defer close(events) - for scanner.Scan() { - atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - if !sendEvent(scanEvent{line: scanner.Text()}) { - return - } - } - if err := scanner.Err(); err != nil { - _ = sendEvent(scanEvent{err: err}) - } - }() - defer close(done) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) streamInterval := time.Duration(0) if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { @@ -1305,95 +2872,178 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp return } errorEventSent = true - payload := map[string]any{ - "type": "error", - "sequence_number": 0, - "error": map[string]any{ - "type": "upstream_error", - "message": reason, - "code": reason, - }, + payload := `{"type":"error","sequence_number":0,"error":{"type":"upstream_error","message":` + strconv.Quote(reason) + `,"code":` + strconv.Quote(reason) + `}}` + if err := flushBuffered(); err != nil { + clientDisconnected = true + return } - if b, err := json.Marshal(payload); err == nil { - _, _ = fmt.Fprintf(w, "data: %s\n\n", b) - flusher.Flush() + if _, err := bufferedWriter.WriteString("data: " + payload + "\n\n"); err != nil { + clientDisconnected = true + return + } + if err := flushBuffered(); err != nil { + clientDisconnected = true } } needModelReplace := originalModel != 
mappedModel + resultWithUsage := func() *openaiStreamingResult { + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs} + } + finalizeStream := func() (*openaiStreamingResult, error) { + if !clientDisconnected { + if err := flushBuffered(); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage") + } + } + return resultWithUsage(), nil + } + handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) { + if scanErr == nil { + return nil, nil, false + } + // 客户端断开/取消请求时,上游读取往往会返回 context canceled。 + // /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。 + if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) { + logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage") + return resultWithUsage(), nil, true + } + // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage + if clientDisconnected { + logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", scanErr) + return resultWithUsage(), nil, true + } + if errors.Is(scanErr, bufio.ErrTooLong) { + logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr) + sendErrorEvent("response_too_large") + return resultWithUsage(), scanErr, true + } + sendErrorEvent("stream_read_error") + return resultWithUsage(), fmt.Errorf("stream read error: %w", scanErr), true + } + processSSELine := func(line string, queueDrained bool) { + lastDataAt = time.Now() + + // Extract data from SSE line (supports both "data: " and "data:" formats) + if data, ok := extractOpenAISSEDataLine(line); ok { + + // Replace model in response if needed. + // Fast path: most events do not contain model field values. 
+ if needModelReplace && mappedModel != "" && strings.Contains(data, mappedModel) { + line = s.replaceModelInSSELine(line, mappedModel, originalModel) + } + + dataBytes := []byte(data) + + // Correct Codex tool calls if needed (apply_patch -> edit, etc.) + if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected { + dataBytes = correctedData + data = string(correctedData) + line = "data: " + data + } + + // 写入客户端(客户端断开后继续 drain 上游) + if !clientDisconnected { + shouldFlush := queueDrained + if firstTokenMs == nil && data != "" && data != "[DONE]" { + // 保证首个 token 事件尽快出站,避免影响 TTFT。 + shouldFlush = true + } + if _, err := bufferedWriter.WriteString(line); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") + } else if _, err := bufferedWriter.WriteString("\n"); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") + } else if shouldFlush { + if err := flushBuffered(); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing") + } + } + } + + // Record first token time + if firstTokenMs == nil && data != "" && data != "[DONE]" { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseSSEUsageBytes(dataBytes, usage) + return + } + + // Forward non-data lines as-is + if !clientDisconnected { + if _, err := bufferedWriter.WriteString(line); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") + } else if _, err := bufferedWriter.WriteString("\n"); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client 
disconnected during streaming, continuing to drain upstream for billing") + } else if queueDrained { + if err := flushBuffered(); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing") + } + } + } + } + + // 无超时/无 keepalive 的常见路径走同步扫描,减少 goroutine 与 channel 开销。 + if streamInterval <= 0 && keepaliveInterval <= 0 { + defer putSSEScannerBuf64K(scanBuf) + for scanner.Scan() { + processSSELine(scanner.Text(), true) + } + if result, err, done := handleScanErr(scanner.Err()); done { + return result, err + } + return finalizeStream() + } + + type scanEvent struct { + line string + err error + } + // 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理 + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) + defer close(events) + for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }(scanBuf) + defer close(done) for { select { case ev, ok := <-events: if !ok { - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil + return finalizeStream() } - if ev.err != nil { - // 客户端断开/取消请求时,上游读取往往会返回 context canceled。 - // /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。 - if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { - log.Printf("Context canceled during streaming, returning collected usage") - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil - } - // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage - if 
clientDisconnected { - log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err) - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil - } - if errors.Is(ev.err, bufio.ErrTooLong) { - log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) - sendErrorEvent("response_too_large") - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err - } - sendErrorEvent("stream_read_error") - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) - } - - line := ev.line - lastDataAt = time.Now() - - // Extract data from SSE line (supports both "data: " and "data:" formats) - if openaiSSEDataRe.MatchString(line) { - data := openaiSSEDataRe.ReplaceAllString(line, "") - - // Replace model in response if needed - if needModelReplace { - line = s.replaceModelInSSELine(line, mappedModel, originalModel) - } - - // Correct Codex tool calls if needed (apply_patch -> edit, etc.) 
- if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected { - data = correctedData - line = "data: " + correctedData - } - - // 写入客户端(客户端断开后继续 drain 上游) - if !clientDisconnected { - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - clientDisconnected = true - log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") - } else { - flusher.Flush() - } - } - - // Record first token time - if firstTokenMs == nil && data != "" && data != "[DONE]" { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms - } - s.parseSSEUsage(data, usage) - } else { - // Forward non-data lines as-is - if !clientDisconnected { - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - clientDisconnected = true - log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") - } else { - flusher.Flush() - } - } + if result, err, done := handleScanErr(ev.err); done { + return result, err } + processSSELine(ev.line, len(events) == 0) case <-intervalCh: lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) @@ -1401,16 +3051,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp continue } if clientDisconnected { - log.Printf("Upstream timeout after client disconnect, returning collected usage") - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil + logger.LegacyPrintf("service.openai_gateway", "Upstream timeout after client disconnect, returning collected usage") + return resultWithUsage(), nil } - log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) + logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) // 处理流超时,可能标记账户为临时不可调度或错误状态 if s.rateLimitService != nil { s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel) } 
sendErrorEvent("stream_timeout") - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") + return resultWithUsage(), fmt.Errorf("stream data interval timeout") case <-keepaliveCh: if clientDisconnected { @@ -1419,51 +3069,61 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if time.Since(lastDataAt) < keepaliveInterval { continue } - if _, err := fmt.Fprint(w, ":\n\n"); err != nil { + if _, err := bufferedWriter.WriteString(":\n\n"); err != nil { clientDisconnected = true - log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") continue } - flusher.Flush() + if err := flushBuffered(); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during keepalive flush, continuing to drain upstream for billing") + } } } } +// extractOpenAISSEDataLine 低开销提取 SSE `data:` 行内容。 +// 兼容 `data: xxx` 与 `data:xxx` 两种格式。 +func extractOpenAISSEDataLine(line string) (string, bool) { + if !strings.HasPrefix(line, "data:") { + return "", false + } + start := len("data:") + for start < len(line) { + if line[start] != ' ' && line[start] != ' ' { + break + } + start++ + } + return line[start:], true +} + func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string { - if !openaiSSEDataRe.MatchString(line) { + data, ok := extractOpenAISSEDataLine(line) + if !ok { return line } - data := openaiSSEDataRe.ReplaceAllString(line, "") if data == "" || data == "[DONE]" { return line } - var event map[string]any - if err := json.Unmarshal([]byte(data), &event); err != nil { - return line - } - - // Replace model in response - if m, ok := event["model"].(string); ok && m == fromModel { - event["model"] = toModel - newData, err := json.Marshal(event) + // 
使用 gjson 精确检查 model 字段,避免全量 JSON 反序列化 + if m := gjson.Get(data, "model"); m.Exists() && m.Str == fromModel { + newData, err := sjson.Set(data, "model", toModel) if err != nil { return line } - return "data: " + string(newData) + return "data: " + newData } - // Check nested response - if response, ok := event["response"].(map[string]any); ok { - if m, ok := response["model"].(string); ok && m == fromModel { - response["model"] = toModel - newData, err := json.Marshal(event) - if err != nil { - return line - } - return "data: " + string(newData) + // 检查嵌套的 response.model 字段 + if m := gjson.Get(data, "response.model"); m.Exists() && m.Str == fromModel { + newData, err := sjson.Set(data, "response.model", toModel) + if err != nil { + return line } + return "data: " + newData } return line @@ -1475,39 +3135,64 @@ func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byt return body } - bodyStr := string(body) - corrected, changed := s.toolCorrector.CorrectToolCallsInSSEData(bodyStr) + corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(body) if changed { - return []byte(corrected) + return corrected } return body } func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) { - // Parse response.completed event for usage (OpenAI Responses format) - var event struct { - Type string `json:"type"` - Response struct { - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - InputTokenDetails struct { - CachedTokens int `json:"cached_tokens"` - } `json:"input_tokens_details"` - } `json:"usage"` - } `json:"response"` + s.parseSSEUsageBytes([]byte(data), usage) +} + +func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsage) { + if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) { + return + } + // 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。 + if len(data) < 80 || !bytes.Contains(data, []byte(`"response.completed"`)) { + return + } + if 
gjson.GetBytes(data, "type").String() != "response.completed" { + return } - if json.Unmarshal([]byte(data), &event) == nil && event.Type == "response.completed" { - usage.InputTokens = event.Response.Usage.InputTokens - usage.OutputTokens = event.Response.Usage.OutputTokens - usage.CacheReadInputTokens = event.Response.Usage.InputTokenDetails.CachedTokens + usage.InputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens").Int()) + usage.OutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens").Int()) + usage.CacheReadInputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens_details.cached_tokens").Int()) +} + +func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) { + if len(body) == 0 || !gjson.ValidBytes(body) { + return OpenAIUsage{}, false } + values := gjson.GetManyBytes( + body, + "usage.input_tokens", + "usage.output_tokens", + "usage.input_tokens_details.cached_tokens", + ) + return OpenAIUsage{ + InputTokens: int(values[0].Int()), + OutputTokens: int(values[1].Int()), + CacheReadInputTokens: int(values[2].Int()), + }, true } func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) { - body, err := io.ReadAll(resp.Body) + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } return nil, err } @@ -1518,32 +3203,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r } } - // Parse usage - var response struct { - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int 
`json:"output_tokens"` - InputTokenDetails struct { - CachedTokens int `json:"cached_tokens"` - } `json:"input_tokens_details"` - } `json:"usage"` - } - if err := json.Unmarshal(body, &response); err != nil { - return nil, fmt.Errorf("parse response: %w", err) - } - - usage := &OpenAIUsage{ - InputTokens: response.Usage.InputTokens, - OutputTokens: response.Usage.OutputTokens, - CacheReadInputTokens: response.Usage.InputTokenDetails.CachedTokens, + usageValue, usageOK := extractOpenAIUsageFromJSONBytes(body) + if !usageOK { + return nil, fmt.Errorf("parse response: invalid json response") } + usage := &usageValue // Replace model in response if needed if originalModel != mappedModel { body = s.replaceModelInResponseBody(body, mappedModel, originalModel) } - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := "application/json" if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled { @@ -1568,19 +3239,8 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin. usage := &OpenAIUsage{} if ok { - var response struct { - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - InputTokenDetails struct { - CachedTokens int `json:"cached_tokens"` - } `json:"input_tokens_details"` - } `json:"usage"` - } - if err := json.Unmarshal(finalResponse, &response); err == nil { - usage.InputTokens = response.Usage.InputTokens - usage.OutputTokens = response.Usage.OutputTokens - usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens + if parsedUsage, parsed := extractOpenAIUsageFromJSONBytes(finalResponse); parsed { + *usage = parsedUsage } body = finalResponse if originalModel != mappedModel { @@ -1596,7 +3256,7 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin. 
body = []byte(bodyText) } - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := "application/json; charset=utf-8" if !ok { @@ -1613,23 +3273,17 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin. func extractCodexFinalResponse(body string) ([]byte, bool) { lines := strings.Split(body, "\n") for _, line := range lines { - if !openaiSSEDataRe.MatchString(line) { + data, ok := extractOpenAISSEDataLine(line) + if !ok { continue } - data := openaiSSEDataRe.ReplaceAllString(line, "") if data == "" || data == "[DONE]" { continue } - var event struct { - Type string `json:"type"` - Response json.RawMessage `json:"response"` - } - if json.Unmarshal([]byte(data), &event) != nil { - continue - } - if event.Type == "response.done" || event.Type == "response.completed" { - if len(event.Response) > 0 { - return event.Response, true + eventType := gjson.Get(data, "type").String() + if eventType == "response.done" || eventType == "response.completed" { + if response := gjson.Get(data, "response"); response.Exists() && response.Type == gjson.JSON && response.Raw != "" { + return []byte(response.Raw), true } } } @@ -1640,14 +3294,14 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage { usage := &OpenAIUsage{} lines := strings.Split(body, "\n") for _, line := range lines { - if !openaiSSEDataRe.MatchString(line) { + data, ok := extractOpenAISSEDataLine(line) + if !ok { continue } - data := openaiSSEDataRe.ReplaceAllString(line, "") if data == "" || data == "[DONE]" { continue } - s.parseSSEUsage(data, usage) + s.parseSSEUsageBytes([]byte(data), usage) } return usage } @@ -1655,7 +3309,7 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage { func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string { 
lines := strings.Split(body, "\n") for i, line := range lines { - if !openaiSSEDataRe.MatchString(line) { + if _, ok := extractOpenAISSEDataLine(line); !ok { continue } lines[i] = s.replaceModelInSSELine(line, fromModel, toModel) @@ -1682,24 +3336,31 @@ func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, erro return normalized, nil } +// buildOpenAIResponsesURL 组装 OpenAI Responses 端点。 +// - base 以 /v1 结尾:追加 /responses +// - base 已是 /responses:原样返回 +// - 其他情况:追加 /v1/responses +func buildOpenAIResponsesURL(base string) string { + normalized := strings.TrimRight(strings.TrimSpace(base), "/") + if strings.HasSuffix(normalized, "/responses") { + return normalized + } + if strings.HasSuffix(normalized, "/v1") { + return normalized + "/responses" + } + return normalized + "/v1/responses" +} + func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte { - var resp map[string]any - if err := json.Unmarshal(body, &resp); err != nil { - return body + // 使用 gjson/sjson 精确替换 model 字段,避免全量 JSON 反序列化 + if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel { + newBody, err := sjson.SetBytes(body, "model", toModel) + if err != nil { + return body + } + return newBody } - - model, ok := resp["model"].(string) - if !ok || model != fromModel { - return body - } - - resp["model"] = toModel - newBody, err := json.Marshal(resp) - if err != nil { - return body - } - - return newBody + return body } // OpenAIRecordUsageInput input for recording usage @@ -1779,6 +3440,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec AccountRateMultiplier: &accountRateMultiplier, BillingType: billingType, Stream: result.Stream, + OpenAIWSMode: result.OpenAIWSMode, DurationMs: &durationMs, FirstTokenMs: result.FirstTokenMs, CreatedAt: time.Now(), @@ -1803,7 +3465,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec inserted, err := 
s.usageLogRepo.Create(ctx, usageLog) if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } @@ -1826,10 +3488,18 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec // Update API key quota if applicable (only for balance mode with quota set) if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil { if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { - log.Printf("Update API key quota failed: %v", err) + logger.LegacyPrintf("service.openai_gateway", "Update API key quota failed: %v", err) } } + // Update API Key rate limit usage + if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil { + if err := input.APIKeyService.UpdateRateLimitUsage(ctx, apiKey.ID, cost.ActualCost); err != nil { + logger.LegacyPrintf("service.openai_gateway", "Update API key rate limit usage failed: %v", err) + } + s.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(apiKey.ID, cost.ActualCost) + } + // Schedule batch update for account last_used_at s.deferredService.ScheduleLastUsedUpdate(account.ID) @@ -1904,16 +3574,41 @@ func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot { return snapshot } -// updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field -func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) { +func codexSnapshotBaseTime(snapshot *OpenAICodexUsageSnapshot, fallback time.Time) time.Time { if snapshot == nil { - return + return fallback + } + if snapshot.UpdatedAt == "" { + 
return fallback + } + base, err := time.Parse(time.RFC3339, snapshot.UpdatedAt) + if err != nil { + return fallback + } + return base +} + +func codexResetAtRFC3339(base time.Time, resetAfterSeconds *int) *string { + if resetAfterSeconds == nil { + return nil + } + sec := *resetAfterSeconds + if sec < 0 { + sec = 0 + } + resetAt := base.Add(time.Duration(sec) * time.Second).Format(time.RFC3339) + return &resetAt +} + +func buildCodexUsageExtraUpdates(snapshot *OpenAICodexUsageSnapshot, fallbackNow time.Time) map[string]any { + if snapshot == nil { + return nil } - // Convert snapshot to map for merging into Extra + baseTime := codexSnapshotBaseTime(snapshot, fallbackNow) updates := make(map[string]any) - // Save raw primary/secondary fields for debugging/tracing + // 保存原始 primary/secondary 字段,便于排查问题 if snapshot.PrimaryUsedPercent != nil { updates["codex_primary_used_percent"] = *snapshot.PrimaryUsedPercent } @@ -1935,9 +3630,9 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc if snapshot.PrimaryOverSecondaryPercent != nil { updates["codex_primary_over_secondary_percent"] = *snapshot.PrimaryOverSecondaryPercent } - updates["codex_usage_updated_at"] = snapshot.UpdatedAt + updates["codex_usage_updated_at"] = baseTime.Format(time.RFC3339) - // Normalize to canonical 5h/7d fields + // 归一化到 5h/7d 规范字段 if normalized := snapshot.Normalize(); normalized != nil { if normalized.Used5hPercent != nil { updates["codex_5h_used_percent"] = *normalized.Used5hPercent @@ -1957,6 +3652,29 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc if normalized.Window7dMinutes != nil { updates["codex_7d_window_minutes"] = *normalized.Window7dMinutes } + if reset5hAt := codexResetAtRFC3339(baseTime, normalized.Reset5hSeconds); reset5hAt != nil { + updates["codex_5h_reset_at"] = *reset5hAt + } + if reset7dAt := codexResetAtRFC3339(baseTime, normalized.Reset7dSeconds); reset7dAt != nil { + updates["codex_7d_reset_at"] = *reset7dAt 
+ } + } + + return updates +} + +// updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field +func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) { + if snapshot == nil { + return + } + if s == nil || s.accountRepo == nil { + return + } + + updates := buildCodexUsageExtraUpdates(snapshot, time.Now()) + if len(updates) == 0 { + return } // Update account's Extra field asynchronously @@ -2013,6 +3731,106 @@ func deriveOpenAIReasoningEffortFromModel(model string) string { return normalizeOpenAIReasoningEffort(parts[len(parts)-1]) } +func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, promptCacheKey string) { + if len(body) == 0 { + return "", false, "" + } + + model = strings.TrimSpace(gjson.GetBytes(body, "model").String()) + stream = gjson.GetBytes(body, "stream").Bool() + promptCacheKey = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) + return model, stream, promptCacheKey +} + +// normalizeOpenAIPassthroughOAuthBody 将透传 OAuth 请求体收敛为旧链路关键行为: +// 1) store=false 2) stream=true +func normalizeOpenAIPassthroughOAuthBody(body []byte) ([]byte, bool, error) { + if len(body) == 0 { + return body, false, nil + } + + normalized := body + changed := false + + if store := gjson.GetBytes(normalized, "store"); !store.Exists() || store.Type != gjson.False { + next, err := sjson.SetBytes(normalized, "store", false) + if err != nil { + return body, false, fmt.Errorf("normalize passthrough body store=false: %w", err) + } + normalized = next + changed = true + } + + if stream := gjson.GetBytes(normalized, "stream"); !stream.Exists() || stream.Type != gjson.True { + next, err := sjson.SetBytes(normalized, "stream", true) + if err != nil { + return body, false, fmt.Errorf("normalize passthrough body stream=true: %w", err) + } + normalized = next + changed = true + } + + return normalized, changed, nil +} + +func 
detectOpenAIPassthroughInstructionsRejectReason(reqModel string, body []byte) string { + model := strings.ToLower(strings.TrimSpace(reqModel)) + if !strings.Contains(model, "codex") { + return "" + } + + instructions := gjson.GetBytes(body, "instructions") + if !instructions.Exists() { + return "instructions_missing" + } + if instructions.Type != gjson.String { + return "instructions_not_string" + } + if strings.TrimSpace(instructions.String()) == "" { + return "instructions_empty" + } + return "" +} + +func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string { + reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String()) + if reasoningEffort == "" { + reasoningEffort = strings.TrimSpace(gjson.GetBytes(body, "reasoning_effort").String()) + } + if reasoningEffort != "" { + normalized := normalizeOpenAIReasoningEffort(reasoningEffort) + if normalized == "" { + return nil + } + return &normalized + } + + value := deriveOpenAIReasoningEffortFromModel(requestedModel) + if value == "" { + return nil + } + return &value +} + +func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error) { + if c != nil { + if cached, ok := c.Get(OpenAIParsedRequestBodyKey); ok { + if reqBody, ok := cached.(map[string]any); ok && reqBody != nil { + return reqBody, nil + } + } + } + + var reqBody map[string]any + if err := json.Unmarshal(body, &reqBody); err != nil { + return nil, fmt.Errorf("parse request: %w", err) + } + if c != nil { + c.Set(OpenAIParsedRequestBodyKey, reqBody) + } + return reqBody, nil +} + func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string { if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present { if value == "" { diff --git a/backend/internal/service/openai_gateway_service_codex_cli_only_test.go b/backend/internal/service/openai_gateway_service_codex_cli_only_test.go new file mode 100644 index 00000000..d7c95ada --- /dev/null +++ 
b/backend/internal/service/openai_gateway_service_codex_cli_only_test.go @@ -0,0 +1,266 @@ +package service + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type stubCodexRestrictionDetector struct { + result CodexClientRestrictionDetectionResult +} + +func (s *stubCodexRestrictionDetector) Detect(_ *gin.Context, _ *Account) CodexClientRestrictionDetectionResult { + return s.result +} + +func TestOpenAIGatewayService_GetCodexClientRestrictionDetector(t *testing.T) { + gin.SetMode(gin.TestMode) + + t.Run("使用注入的 detector", func(t *testing.T) { + expected := &stubCodexRestrictionDetector{ + result: CodexClientRestrictionDetectionResult{Enabled: true, Matched: true, Reason: "stub"}, + } + svc := &OpenAIGatewayService{codexDetector: expected} + + got := svc.getCodexClientRestrictionDetector() + require.Same(t, expected, got) + }) + + t.Run("service 为 nil 时返回默认 detector", func(t *testing.T) { + var svc *OpenAIGatewayService + got := svc.getCodexClientRestrictionDetector() + require.NotNil(t, got) + }) + + t.Run("service 未注入 detector 时返回默认 detector", func(t *testing.T) { + svc := &OpenAIGatewayService{cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: true}}} + got := svc.getCodexClientRestrictionDetector() + require.NotNil(t, got) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + c.Request.Header.Set("User-Agent", "curl/8.0") + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{"codex_cli_only": true}} + + result := got.Detect(c, account) + require.True(t, result.Enabled) + require.True(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonForceCodexCLI, result.Reason) + }) +} + +func TestGetAPIKeyIDFromContext(t *testing.T) { + 
gin.SetMode(gin.TestMode) + + t.Run("context 为 nil", func(t *testing.T) { + require.Equal(t, int64(0), getAPIKeyIDFromContext(nil)) + }) + + t.Run("上下文没有 api_key", func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + require.Equal(t, int64(0), getAPIKeyIDFromContext(c)) + }) + + t.Run("api_key 类型错误", func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Set("api_key", "not-api-key") + require.Equal(t, int64(0), getAPIKeyIDFromContext(c)) + }) + + t.Run("api_key 指针为空", func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + var k *APIKey + c.Set("api_key", k) + require.Equal(t, int64(0), getAPIKeyIDFromContext(c)) + }) + + t.Run("正常读取 api_key_id", func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Set("api_key", &APIKey{ID: 12345}) + require.Equal(t, int64(12345), getAPIKeyIDFromContext(c)) + }) +} + +func TestLogCodexCLIOnlyDetection_NilSafety(t *testing.T) { + // 不校验日志内容,仅保证在 nil 入参下不会 panic。 + require.NotPanics(t, func() { + logCodexCLIOnlyDetection(context.TODO(), nil, nil, 0, CodexClientRestrictionDetectionResult{Enabled: true, Matched: false, Reason: "test"}, nil) + logCodexCLIOnlyDetection(context.Background(), nil, nil, 0, CodexClientRestrictionDetectionResult{Enabled: false, Matched: false, Reason: "disabled"}, nil) + }) +} + +func TestLogCodexCLIOnlyDetection_OnlyLogsRejected(t *testing.T) { + logSink, restore := captureStructuredLog(t) + defer restore() + + account := &Account{ID: 1001} + logCodexCLIOnlyDetection(context.Background(), nil, account, 2002, CodexClientRestrictionDetectionResult{ + Enabled: true, + Matched: true, + Reason: CodexClientRestrictionReasonMatchedUA, + }, nil) + logCodexCLIOnlyDetection(context.Background(), nil, account, 2002, CodexClientRestrictionDetectionResult{ + Enabled: true, + Matched: false, + Reason: CodexClientRestrictionReasonNotMatchedUA, + }, nil) + + 
require.False(t, logSink.ContainsMessage("OpenAI codex_cli_only 允许官方客户端请求")) + require.True(t, logSink.ContainsMessage("OpenAI codex_cli_only 拒绝非官方客户端请求")) +} + +func TestLogCodexCLIOnlyDetection_RejectedIncludesRequestDetails(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown") + c.Request.Header.Set("Content-Type", "application/json") + c.Request.Header.Set("OpenAI-Beta", "assistants=v2") + + body := []byte(`{"model":"gpt-5.2","stream":false,"prompt_cache_key":"pc-123","access_token":"secret-token","input":[{"type":"text","text":"hello"}]}`) + account := &Account{ID: 1001} + logCodexCLIOnlyDetection(context.Background(), c, account, 2002, CodexClientRestrictionDetectionResult{ + Enabled: true, + Matched: false, + Reason: CodexClientRestrictionReasonNotMatchedUA, + }, body) + + require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown")) + require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.2")) + require.True(t, logSink.ContainsFieldValue("request_query", "trace=1")) + require.True(t, logSink.ContainsFieldValue("request_prompt_cache_key_sha256", hashSensitiveValueForLog("pc-123"))) + require.True(t, logSink.ContainsFieldValue("request_headers", "openai-beta")) + require.True(t, logSink.ContainsField("request_body_size")) + require.False(t, logSink.ContainsField("request_body_preview")) +} + +func TestLogOpenAIInstructionsRequiredDebug_LogsRequestDetails(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = 
httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "curl/8.0") + c.Request.Header.Set("Content-Type", "application/json") + c.Request.Header.Set("OpenAI-Beta", "assistants=v2") + + body := []byte(`{"model":"gpt-5.1-codex","stream":false,"prompt_cache_key":"pc-abc","access_token":"secret-token","input":[{"type":"text","text":"hello"}]}`) + account := &Account{ID: 1001, Name: "codex max套餐"} + + logOpenAIInstructionsRequiredDebug( + context.Background(), + c, + account, + http.StatusBadRequest, + "Instructions are required", + body, + []byte(`{"error":{"message":"Instructions are required","type":"invalid_request_error","param":"instructions","code":"missing_required_parameter"}}`), + ) + + require.True(t, logSink.ContainsMessageAtLevel("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查", "warn")) + require.True(t, logSink.ContainsFieldValue("request_user_agent", "curl/8.0")) + require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.1-codex")) + require.True(t, logSink.ContainsFieldValue("request_query", "trace=1")) + require.True(t, logSink.ContainsFieldValue("account_name", "codex max套餐")) + require.True(t, logSink.ContainsFieldValue("request_headers", "openai-beta")) + require.True(t, logSink.ContainsField("request_body_size")) + require.False(t, logSink.ContainsField("request_body_preview")) +} + +func TestLogOpenAIInstructionsRequiredDebug_NonTargetErrorSkipped(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "curl/8.0") + body := []byte(`{"model":"gpt-5.1-codex","stream":false}`) + + logOpenAIInstructionsRequiredDebug( + context.Background(), + c, + &Account{ID: 1001}, + http.StatusForbidden, + "forbidden", + body, + 
[]byte(`{"error":{"message":"forbidden"}}`), + ) + + require.False(t, logSink.ContainsMessage("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查")) +} + +func TestOpenAIGatewayService_Forward_LogsInstructionsRequiredDetails(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + c.Request.Header.Set("Content-Type", "application/json") + c.Request.Header.Set("OpenAI-Beta", "assistants=v2") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusBadRequest, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "x-request-id": []string{"rid-upstream"}, + }, + Body: io.NopCloser(strings.NewReader(`{"error":{"message":"Missing required parameter: 'instructions'","type":"invalid_request_error","param":"instructions","code":"missing_required_parameter"}}`)), + }, + } + svc := &OpenAIGatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ForceCodexCLI: false}, + }, + httpUpstream: upstream, + } + account := &Account{ + ID: 1001, + Name: "codex max套餐", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{"api_key": "sk-test"}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + body := []byte(`{"model":"gpt-5.1-codex","stream":false,"input":[{"type":"text","text":"hello"}],"prompt_cache_key":"pc-forward","access_token":"secret-token"}`) + + _, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Equal(t, http.StatusBadGateway, rec.Code) + require.Contains(t, err.Error(), "upstream error: 400") + + require.True(t, logSink.ContainsMessageAtLevel("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查", "warn")) + 
require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.1.0")) + require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.1-codex")) + require.True(t, logSink.ContainsFieldValue("request_headers", "openai-beta")) + require.True(t, logSink.ContainsField("request_body_size")) + require.False(t, logSink.ContainsField("request_body_preview")) +} diff --git a/backend/internal/service/openai_gateway_service_codex_snapshot_test.go b/backend/internal/service/openai_gateway_service_codex_snapshot_test.go new file mode 100644 index 00000000..654dd4ca --- /dev/null +++ b/backend/internal/service/openai_gateway_service_codex_snapshot_test.go @@ -0,0 +1,192 @@ +package service + +import ( + "testing" + "time" +) + +func TestCodexSnapshotBaseTime(t *testing.T) { + fallback := time.Date(2026, 2, 20, 9, 0, 0, 0, time.UTC) + + t.Run("nil snapshot uses fallback", func(t *testing.T) { + got := codexSnapshotBaseTime(nil, fallback) + if !got.Equal(fallback) { + t.Fatalf("got %v, want fallback %v", got, fallback) + } + }) + + t.Run("empty updatedAt uses fallback", func(t *testing.T) { + got := codexSnapshotBaseTime(&OpenAICodexUsageSnapshot{}, fallback) + if !got.Equal(fallback) { + t.Fatalf("got %v, want fallback %v", got, fallback) + } + }) + + t.Run("valid updatedAt wins", func(t *testing.T) { + got := codexSnapshotBaseTime(&OpenAICodexUsageSnapshot{UpdatedAt: "2026-02-16T10:00:00Z"}, fallback) + want := time.Date(2026, 2, 16, 10, 0, 0, 0, time.UTC) + if !got.Equal(want) { + t.Fatalf("got %v, want %v", got, want) + } + }) + + t.Run("invalid updatedAt uses fallback", func(t *testing.T) { + got := codexSnapshotBaseTime(&OpenAICodexUsageSnapshot{UpdatedAt: "invalid"}, fallback) + if !got.Equal(fallback) { + t.Fatalf("got %v, want fallback %v", got, fallback) + } + }) +} + +func TestCodexResetAtRFC3339(t *testing.T) { + base := time.Date(2026, 2, 16, 10, 0, 0, 0, time.UTC) + + t.Run("nil reset returns nil", func(t *testing.T) { + if got := 
codexResetAtRFC3339(base, nil); got != nil { + t.Fatalf("expected nil, got %v", *got) + } + }) + + t.Run("positive seconds", func(t *testing.T) { + sec := 90 + got := codexResetAtRFC3339(base, &sec) + if got == nil { + t.Fatal("expected non-nil") + } + if *got != "2026-02-16T10:01:30Z" { + t.Fatalf("got %s, want %s", *got, "2026-02-16T10:01:30Z") + } + }) + + t.Run("negative seconds clamp to base", func(t *testing.T) { + sec := -3 + got := codexResetAtRFC3339(base, &sec) + if got == nil { + t.Fatal("expected non-nil") + } + if *got != "2026-02-16T10:00:00Z" { + t.Fatalf("got %s, want %s", *got, "2026-02-16T10:00:00Z") + } + }) +} + +func TestBuildCodexUsageExtraUpdates_UsesSnapshotUpdatedAt(t *testing.T) { + primaryUsed := 88.0 + primaryReset := 86400 + primaryWindow := 10080 + secondaryUsed := 12.0 + secondaryReset := 3600 + secondaryWindow := 300 + + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: &primaryUsed, + PrimaryResetAfterSeconds: &primaryReset, + PrimaryWindowMinutes: &primaryWindow, + SecondaryUsedPercent: &secondaryUsed, + SecondaryResetAfterSeconds: &secondaryReset, + SecondaryWindowMinutes: &secondaryWindow, + UpdatedAt: "2026-02-16T10:00:00Z", + } + + updates := buildCodexUsageExtraUpdates(snapshot, time.Date(2026, 2, 20, 8, 0, 0, 0, time.UTC)) + if updates == nil { + t.Fatal("expected non-nil updates") + } + + if got := updates["codex_usage_updated_at"]; got != "2026-02-16T10:00:00Z" { + t.Fatalf("codex_usage_updated_at = %v, want %s", got, "2026-02-16T10:00:00Z") + } + if got := updates["codex_5h_reset_at"]; got != "2026-02-16T11:00:00Z" { + t.Fatalf("codex_5h_reset_at = %v, want %s", got, "2026-02-16T11:00:00Z") + } + if got := updates["codex_7d_reset_at"]; got != "2026-02-17T10:00:00Z" { + t.Fatalf("codex_7d_reset_at = %v, want %s", got, "2026-02-17T10:00:00Z") + } +} + +func TestBuildCodexUsageExtraUpdates_FallbackToNowWhenUpdatedAtInvalid(t *testing.T) { + primaryUsed := 15.0 + primaryReset := 30 + primaryWindow := 300 + + 
fallbackNow := time.Date(2026, 2, 20, 8, 30, 0, 0, time.UTC) + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: &primaryUsed, + PrimaryResetAfterSeconds: &primaryReset, + PrimaryWindowMinutes: &primaryWindow, + UpdatedAt: "invalid-time", + } + + updates := buildCodexUsageExtraUpdates(snapshot, fallbackNow) + if updates == nil { + t.Fatal("expected non-nil updates") + } + + if got := updates["codex_usage_updated_at"]; got != "2026-02-20T08:30:00Z" { + t.Fatalf("codex_usage_updated_at = %v, want %s", got, "2026-02-20T08:30:00Z") + } + if got := updates["codex_5h_reset_at"]; got != "2026-02-20T08:30:30Z" { + t.Fatalf("codex_5h_reset_at = %v, want %s", got, "2026-02-20T08:30:30Z") + } +} + +func TestBuildCodexUsageExtraUpdates_ClampNegativeResetSeconds(t *testing.T) { + primaryUsed := 90.0 + primaryReset := 7200 + primaryWindow := 10080 + secondaryUsed := 100.0 + secondaryReset := -15 + secondaryWindow := 300 + + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: &primaryUsed, + PrimaryResetAfterSeconds: &primaryReset, + PrimaryWindowMinutes: &primaryWindow, + SecondaryUsedPercent: &secondaryUsed, + SecondaryResetAfterSeconds: &secondaryReset, + SecondaryWindowMinutes: &secondaryWindow, + UpdatedAt: "2026-02-16T10:00:00Z", + } + + updates := buildCodexUsageExtraUpdates(snapshot, time.Time{}) + if updates == nil { + t.Fatal("expected non-nil updates") + } + + if got := updates["codex_5h_reset_after_seconds"]; got != -15 { + t.Fatalf("codex_5h_reset_after_seconds = %v, want %d", got, -15) + } + if got := updates["codex_5h_reset_at"]; got != "2026-02-16T10:00:00Z" { + t.Fatalf("codex_5h_reset_at = %v, want %s", got, "2026-02-16T10:00:00Z") + } +} + +func TestBuildCodexUsageExtraUpdates_NilSnapshot(t *testing.T) { + if got := buildCodexUsageExtraUpdates(nil, time.Now()); got != nil { + t.Fatalf("expected nil updates, got %v", got) + } +} + +func TestBuildCodexUsageExtraUpdates_WithoutNormalizedWindowFields(t *testing.T) { + primaryUsed := 42.0 + 
fallbackNow := time.Date(2026, 2, 20, 9, 15, 0, 0, time.UTC) + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: &primaryUsed, + UpdatedAt: "", + } + + updates := buildCodexUsageExtraUpdates(snapshot, fallbackNow) + if updates == nil { + t.Fatal("expected non-nil updates") + } + + if got := updates["codex_usage_updated_at"]; got != "2026-02-20T09:15:00Z" { + t.Fatalf("codex_usage_updated_at = %v, want %s", got, "2026-02-20T09:15:00Z") + } + if _, ok := updates["codex_5h_reset_at"]; ok { + t.Fatalf("did not expect codex_5h_reset_at in updates: %v", updates["codex_5h_reset_at"]) + } + if _, ok := updates["codex_7d_reset_at"]; ok { + t.Fatalf("did not expect codex_7d_reset_at in updates: %v", updates["codex_7d_reset_at"]) + } +} diff --git a/backend/internal/service/openai_gateway_service_hotpath_test.go b/backend/internal/service/openai_gateway_service_hotpath_test.go new file mode 100644 index 00000000..f73c06c5 --- /dev/null +++ b/backend/internal/service/openai_gateway_service_hotpath_test.go @@ -0,0 +1,141 @@ +package service + +import ( + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestExtractOpenAIRequestMetaFromBody(t *testing.T) { + tests := []struct { + name string + body []byte + wantModel string + wantStream bool + wantPromptKey string + }{ + { + name: "完整字段", + body: []byte(`{"model":"gpt-5","stream":true,"prompt_cache_key":" ses-1 "}`), + wantModel: "gpt-5", + wantStream: true, + wantPromptKey: "ses-1", + }, + { + name: "缺失可选字段", + body: []byte(`{"model":"gpt-4"}`), + wantModel: "gpt-4", + wantStream: false, + wantPromptKey: "", + }, + { + name: "空请求体", + body: nil, + wantModel: "", + wantStream: false, + wantPromptKey: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + model, stream, promptKey := extractOpenAIRequestMetaFromBody(tt.body) + require.Equal(t, tt.wantModel, model) + require.Equal(t, tt.wantStream, stream) + require.Equal(t, 
tt.wantPromptKey, promptKey) + }) + } +} + +func TestExtractOpenAIReasoningEffortFromBody(t *testing.T) { + tests := []struct { + name string + body []byte + model string + wantNil bool + wantValue string + }{ + { + name: "优先读取 reasoning.effort", + body: []byte(`{"reasoning":{"effort":"medium"}}`), + model: "gpt-5-high", + wantNil: false, + wantValue: "medium", + }, + { + name: "兼容 reasoning_effort", + body: []byte(`{"reasoning_effort":"x-high"}`), + model: "", + wantNil: false, + wantValue: "xhigh", + }, + { + name: "minimal 归一化为空", + body: []byte(`{"reasoning":{"effort":"minimal"}}`), + model: "gpt-5-high", + wantNil: true, + }, + { + name: "缺失字段时从模型后缀推导", + body: []byte(`{"input":"hi"}`), + model: "gpt-5-high", + wantNil: false, + wantValue: "high", + }, + { + name: "未知后缀不返回", + body: []byte(`{"input":"hi"}`), + model: "gpt-5-unknown", + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractOpenAIReasoningEffortFromBody(tt.body, tt.model) + if tt.wantNil { + require.Nil(t, got) + return + } + require.NotNil(t, got) + require.Equal(t, tt.wantValue, *got) + }) + } +} + +func TestGetOpenAIRequestBodyMap_UsesContextCache(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + cached := map[string]any{"model": "cached-model", "stream": true} + c.Set(OpenAIParsedRequestBodyKey, cached) + + got, err := getOpenAIRequestBodyMap(c, []byte(`{invalid-json`)) + require.NoError(t, err) + require.Equal(t, cached, got) +} + +func TestGetOpenAIRequestBodyMap_ParseErrorWithoutCache(t *testing.T) { + _, err := getOpenAIRequestBodyMap(nil, []byte(`{invalid-json`)) + require.Error(t, err) + require.Contains(t, err.Error(), "parse request") +} + +func TestGetOpenAIRequestBodyMap_WriteBackContextCache(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + got, err := getOpenAIRequestBodyMap(c, 
[]byte(`{"model":"gpt-5","stream":true}`)) + require.NoError(t, err) + require.Equal(t, "gpt-5", got["model"]) + + cached, ok := c.Get(OpenAIParsedRequestBodyKey) + require.True(t, ok) + cachedMap, ok := cached.(map[string]any) + require.True(t, ok) + require.Equal(t, got, cachedMap) +} diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index ae69a986..4f5f7f3c 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "errors" + "fmt" "io" "net/http" "net/http/httptest" @@ -13,9 +14,15 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/cespare/xxhash/v2" "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" ) +// 编译期接口断言 +var _ AccountRepository = (*stubOpenAIAccountRepo)(nil) +var _ GatewayCache = (*stubGatewayCache)(nil) + type stubOpenAIAccountRepo struct { AccountRepository accounts []Account @@ -50,6 +57,10 @@ func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, pl return result, nil } +func (r stubOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + return r.ListSchedulableByPlatform(ctx, platform) +} + type stubConcurrencyCache struct { ConcurrencyCache loadBatchErr error @@ -124,17 +135,19 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) { svc := &OpenAIGatewayService{} + bodyWithKey := []byte(`{"prompt_cache_key":"ses_aaa"}`) + // 1) session_id header wins c.Request.Header.Set("session_id", "sess-123") c.Request.Header.Set("conversation_id", "conv-456") - h1 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"}) + h1 := svc.GenerateSessionHash(c, bodyWithKey) if h1 == "" { t.Fatalf("expected non-empty hash") } // 2) conversation_id used when session_id absent c.Request.Header.Del("session_id") 
- h2 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"}) + h2 := svc.GenerateSessionHash(c, bodyWithKey) if h2 == "" { t.Fatalf("expected non-empty hash") } @@ -144,7 +157,7 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) { // 3) prompt_cache_key used when both headers absent c.Request.Header.Del("conversation_id") - h3 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"}) + h3 := svc.GenerateSessionHash(c, bodyWithKey) if h3 == "" { t.Fatalf("expected non-empty hash") } @@ -153,12 +166,60 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) { } // 4) empty when no signals - h4 := svc.GenerateSessionHash(c, map[string]any{}) + h4 := svc.GenerateSessionHash(c, []byte(`{}`)) if h4 != "" { t.Fatalf("expected empty hash when no signals") } } +func TestOpenAIGatewayService_GenerateSessionHash_UsesXXHash64(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + + c.Request.Header.Set("session_id", "sess-fixed-value") + svc := &OpenAIGatewayService{} + + got := svc.GenerateSessionHash(c, nil) + want := fmt.Sprintf("%016x", xxhash.Sum64String("sess-fixed-value")) + require.Equal(t, want, got) +} + +func TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + + c.Request.Header.Set("session_id", "sess-legacy-check") + svc := &OpenAIGatewayService{} + + sessionHash := svc.GenerateSessionHash(c, nil) + require.NotEmpty(t, sessionHash) + require.NotNil(t, c.Request) + require.NotNil(t, c.Request.Context()) + require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context())) +} + +func 
TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + + svc := &OpenAIGatewayService{} + seed := "openai_ws_ingress:9:100:200" + + got := svc.GenerateSessionHashWithFallback(c, []byte(`{}`), seed) + want := fmt.Sprintf("%016x", xxhash.Sum64String(seed)) + require.Equal(t, want, got) + require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context())) + + empty := svc.GenerateSessionHashWithFallback(c, []byte(`{}`), " ") + require.Equal(t, "", empty) +} + func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { if c.waitCounts != nil { if count, ok := c.waitCounts[accountID]; ok { @@ -1066,6 +1127,43 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) { } } +func TestOpenAIStreamingReuseScannerBufferAndStillWorks(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + Header: http.Header{}, + } + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"input_tokens_details\":{\"cached_tokens\":3}}}}\n\n")) + }() + + result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 1, 
result.usage.InputTokens) + require.Equal(t, 2, result.usage.OutputTokens) + require.Equal(t, 3, result.usage.CacheReadInputTokens) +} + func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ @@ -1149,3 +1247,332 @@ func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) { t.Fatalf("expected non-allowlisted host to fail") } } + +// ==================== P1-08 修复:model 替换性能优化测试 ==================== + +func TestReplaceModelInSSELine(t *testing.T) { + svc := &OpenAIGatewayService{} + + tests := []struct { + name string + line string + from string + to string + expected string + }{ + { + name: "顶层 model 字段替换", + line: `data: {"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`, + from: "gpt-4o", + to: "my-custom-model", + expected: `data: {"id":"chatcmpl-123","model":"my-custom-model","choices":[]}`, + }, + { + name: "嵌套 response.model 替换", + line: `data: {"type":"response","response":{"id":"resp-1","model":"gpt-4o","output":[]}}`, + from: "gpt-4o", + to: "my-model", + expected: `data: {"type":"response","response":{"id":"resp-1","model":"my-model","output":[]}}`, + }, + { + name: "model 不匹配时不替换", + line: `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`, + from: "gpt-4o", + to: "my-model", + expected: `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`, + }, + { + name: "无 model 字段时不替换", + line: `data: {"id":"chatcmpl-123","choices":[]}`, + from: "gpt-4o", + to: "my-model", + expected: `data: {"id":"chatcmpl-123","choices":[]}`, + }, + { + name: "空 data 行", + line: `data: `, + from: "gpt-4o", + to: "my-model", + expected: `data: `, + }, + { + name: "[DONE] 行", + line: `data: [DONE]`, + from: "gpt-4o", + to: "my-model", + expected: `data: [DONE]`, + }, + { + name: "非 data: 前缀行", + line: `event: message`, + from: "gpt-4o", + to: "my-model", + expected: `event: message`, + }, + { + name: "非法 JSON 不替换", + line: `data: {invalid json}`, + from: "gpt-4o", + to: 
"my-model", + expected: `data: {invalid json}`, + }, + { + name: "无空格 data: 格式", + line: `data:{"id":"x","model":"gpt-4o"}`, + from: "gpt-4o", + to: "my-model", + expected: `data: {"id":"x","model":"my-model"}`, + }, + { + name: "model 名含特殊字符", + line: `data: {"model":"org/model-v2.1-beta"}`, + from: "org/model-v2.1-beta", + to: "custom/alias", + expected: `data: {"model":"custom/alias"}`, + }, + { + name: "空行", + line: "", + from: "gpt-4o", + to: "my-model", + expected: "", + }, + { + name: "保持其他字段不变", + line: `data: {"id":"abc","object":"chat.completion.chunk","model":"gpt-4o","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`, + from: "gpt-4o", + to: "alias", + expected: `data: {"id":"abc","object":"chat.completion.chunk","model":"alias","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`, + }, + { + name: "顶层优先于嵌套:同时存在两个 model", + line: `data: {"model":"gpt-4o","response":{"model":"gpt-4o"}}`, + from: "gpt-4o", + to: "replaced", + expected: `data: {"model":"replaced","response":{"model":"gpt-4o"}}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := svc.replaceModelInSSELine(tt.line, tt.from, tt.to) + require.Equal(t, tt.expected, got) + }) + } +} + +func TestReplaceModelInSSEBody(t *testing.T) { + svc := &OpenAIGatewayService{} + + tests := []struct { + name string + body string + from string + to string + expected string + }{ + { + name: "多行 SSE body 替换", + body: "data: {\"model\":\"gpt-4o\",\"choices\":[]}\n\ndata: {\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n", + from: "gpt-4o", + to: "alias", + expected: "data: {\"model\":\"alias\",\"choices\":[]}\n\ndata: {\"model\":\"alias\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n", + }, + { + name: "无需替换的 body", + body: "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n", + from: "gpt-4o", + to: "alias", + expected: "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n", + 
}, + { + name: "混合 event 和 data 行", + body: "event: message\ndata: {\"model\":\"gpt-4o\"}\n\n", + from: "gpt-4o", + to: "alias", + expected: "event: message\ndata: {\"model\":\"alias\"}\n\n", + }, + { + name: "空 body", + body: "", + from: "gpt-4o", + to: "alias", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := svc.replaceModelInSSEBody(tt.body, tt.from, tt.to) + require.Equal(t, tt.expected, got) + }) + } +} + +func TestReplaceModelInResponseBody(t *testing.T) { + svc := &OpenAIGatewayService{} + + tests := []struct { + name string + body string + from string + to string + expected string + }{ + { + name: "替换顶层 model", + body: `{"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`, + from: "gpt-4o", + to: "alias", + expected: `{"id":"chatcmpl-123","model":"alias","choices":[]}`, + }, + { + name: "model 不匹配不替换", + body: `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`, + from: "gpt-4o", + to: "alias", + expected: `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`, + }, + { + name: "无 model 字段不替换", + body: `{"id":"chatcmpl-123","choices":[]}`, + from: "gpt-4o", + to: "alias", + expected: `{"id":"chatcmpl-123","choices":[]}`, + }, + { + name: "非法 JSON 返回原值", + body: `not json`, + from: "gpt-4o", + to: "alias", + expected: `not json`, + }, + { + name: "空 body 返回原值", + body: ``, + from: "gpt-4o", + to: "alias", + expected: ``, + }, + { + name: "保持嵌套结构不变", + body: `{"model":"gpt-4o","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`, + from: "gpt-4o", + to: "alias", + expected: `{"model":"alias","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := svc.replaceModelInResponseBody([]byte(tt.body), tt.from, tt.to) + require.Equal(t, tt.expected, string(got)) + }) + } +} + +func 
TestExtractOpenAISSEDataLine(t *testing.T) { + tests := []struct { + name string + line string + wantData string + wantOK bool + }{ + {name: "标准格式", line: `data: {"type":"x"}`, wantData: `{"type":"x"}`, wantOK: true}, + {name: "无空格格式", line: `data:{"type":"x"}`, wantData: `{"type":"x"}`, wantOK: true}, + {name: "纯空数据", line: `data: `, wantData: ``, wantOK: true}, + {name: "非 data 行", line: `event: message`, wantData: ``, wantOK: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := extractOpenAISSEDataLine(tt.line) + require.Equal(t, tt.wantOK, ok) + require.Equal(t, tt.wantData, got) + }) + } +} + +func TestParseSSEUsage_SelectiveParsing(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 9, OutputTokens: 8, CacheReadInputTokens: 7} + + // 非 completed 事件,不应覆盖 usage + svc.parseSSEUsage(`{"type":"response.in_progress","response":{"usage":{"input_tokens":1,"output_tokens":2}}}`, usage) + require.Equal(t, 9, usage.InputTokens) + require.Equal(t, 8, usage.OutputTokens) + require.Equal(t, 7, usage.CacheReadInputTokens) + + // completed 事件,应提取 usage + svc.parseSSEUsage(`{"type":"response.completed","response":{"usage":{"input_tokens":3,"output_tokens":5,"input_tokens_details":{"cached_tokens":2}}}}`, usage) + require.Equal(t, 3, usage.InputTokens) + require.Equal(t, 5, usage.OutputTokens) + require.Equal(t, 2, usage.CacheReadInputTokens) +} + +func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) { + body := strings.Join([]string{ + `event: message`, + `data: {"type":"response.in_progress","response":{"id":"resp_1"}}`, + `data: {"type":"response.completed","response":{"id":"resp_1","model":"gpt-4o","usage":{"input_tokens":11,"output_tokens":22,"input_tokens_details":{"cached_tokens":3}}}}`, + `data: [DONE]`, + }, "\n") + + finalResp, ok := extractCodexFinalResponse(body) + require.True(t, ok) + require.Contains(t, string(finalResp), `"id":"resp_1"`) + require.Contains(t, string(finalResp), 
`"input_tokens":11`) +} + +func TestHandleOAuthSSEToJSON_CompletedEventReturnsJSON(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + svc := &OpenAIGatewayService{cfg: &config.Config{}} + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + } + body := []byte(strings.Join([]string{ + `data: {"type":"response.in_progress","response":{"id":"resp_2"}}`, + `data: {"type":"response.completed","response":{"id":"resp_2","model":"gpt-4o","usage":{"input_tokens":7,"output_tokens":9,"input_tokens_details":{"cached_tokens":1}}}}`, + `data: [DONE]`, + }, "\n")) + + usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o") + require.NoError(t, err) + require.NotNil(t, usage) + require.Equal(t, 7, usage.InputTokens) + require.Equal(t, 9, usage.OutputTokens) + require.Equal(t, 1, usage.CacheReadInputTokens) + // Header 可能由上游 Content-Type 透传;关键是 body 已转换为最终 JSON 响应。 + require.NotContains(t, rec.Body.String(), "event:") + require.Contains(t, rec.Body.String(), `"id":"resp_2"`) + require.NotContains(t, rec.Body.String(), "data:") +} + +func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + svc := &OpenAIGatewayService{cfg: &config.Config{}} + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + } + body := []byte(strings.Join([]string{ + `data: {"type":"response.in_progress","response":{"id":"resp_3"}}`, + `data: [DONE]`, + }, "\n")) + + usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o") + require.NoError(t, err) + require.NotNil(t, usage) + require.Equal(t, 0, usage.InputTokens) + require.Contains(t, 
rec.Header().Get("Content-Type"), "text/event-stream") + require.Contains(t, rec.Body.String(), `data: {"type":"response.in_progress"`) +} diff --git a/backend/internal/service/openai_json_optimization_benchmark_test.go b/backend/internal/service/openai_json_optimization_benchmark_test.go new file mode 100644 index 00000000..1737804b --- /dev/null +++ b/backend/internal/service/openai_json_optimization_benchmark_test.go @@ -0,0 +1,357 @@ +package service + +import ( + "encoding/json" + "strconv" + "strings" + "testing" + + "github.com/tidwall/gjson" +) + +var ( + benchmarkToolContinuationBoolSink bool + benchmarkWSParseStringSink string + benchmarkWSParseMapSink map[string]any + benchmarkUsageSink OpenAIUsage +) + +func BenchmarkToolContinuationValidationLegacy(b *testing.B) { + reqBody := benchmarkToolContinuationRequestBody() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchmarkToolContinuationBoolSink = legacyValidateFunctionCallOutputContext(reqBody) + } +} + +func BenchmarkToolContinuationValidationOptimized(b *testing.B) { + reqBody := benchmarkToolContinuationRequestBody() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchmarkToolContinuationBoolSink = optimizedValidateFunctionCallOutputContext(reqBody) + } +} + +func BenchmarkWSIngressPayloadParseLegacy(b *testing.B) { + raw := benchmarkWSIngressPayloadBytes() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + eventType, model, promptCacheKey, previousResponseID, payload, err := legacyParseWSIngressPayload(raw) + if err == nil { + benchmarkWSParseStringSink = eventType + model + promptCacheKey + previousResponseID + benchmarkWSParseMapSink = payload + } + } +} + +func BenchmarkWSIngressPayloadParseOptimized(b *testing.B) { + raw := benchmarkWSIngressPayloadBytes() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + eventType, model, promptCacheKey, previousResponseID, payload, err := optimizedParseWSIngressPayload(raw) + if 
err == nil { + benchmarkWSParseStringSink = eventType + model + promptCacheKey + previousResponseID + benchmarkWSParseMapSink = payload + } + } +} + +func BenchmarkOpenAIUsageExtractLegacy(b *testing.B) { + body := benchmarkOpenAIUsageJSONBytes() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + usage, ok := legacyExtractOpenAIUsageFromJSONBytes(body) + if ok { + benchmarkUsageSink = usage + } + } +} + +func BenchmarkOpenAIUsageExtractOptimized(b *testing.B) { + body := benchmarkOpenAIUsageJSONBytes() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + usage, ok := extractOpenAIUsageFromJSONBytes(body) + if ok { + benchmarkUsageSink = usage + } + } +} + +func benchmarkToolContinuationRequestBody() map[string]any { + input := make([]any, 0, 64) + for i := 0; i < 24; i++ { + input = append(input, map[string]any{ + "type": "text", + "text": "benchmark text", + }) + } + for i := 0; i < 10; i++ { + callID := "call_" + strconv.Itoa(i) + input = append(input, map[string]any{ + "type": "tool_call", + "call_id": callID, + }) + input = append(input, map[string]any{ + "type": "function_call_output", + "call_id": callID, + }) + input = append(input, map[string]any{ + "type": "item_reference", + "id": callID, + }) + } + return map[string]any{ + "model": "gpt-5.3-codex", + "input": input, + } +} + +func benchmarkWSIngressPayloadBytes() []byte { + return []byte(`{"type":"response.create","model":"gpt-5.3-codex","prompt_cache_key":"cache_bench","previous_response_id":"resp_prev_bench","input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"hello"}]}]}`) +} + +func benchmarkOpenAIUsageJSONBytes() []byte { + return []byte(`{"id":"resp_bench","object":"response","model":"gpt-5.3-codex","usage":{"input_tokens":3210,"output_tokens":987,"input_tokens_details":{"cached_tokens":456}}}`) +} + +func legacyValidateFunctionCallOutputContext(reqBody map[string]any) bool { + if !legacyHasFunctionCallOutput(reqBody) { + return true + } 
+ previousResponseID, _ := reqBody["previous_response_id"].(string) + if strings.TrimSpace(previousResponseID) != "" { + return true + } + if legacyHasToolCallContext(reqBody) { + return true + } + if legacyHasFunctionCallOutputMissingCallID(reqBody) { + return false + } + callIDs := legacyFunctionCallOutputCallIDs(reqBody) + return legacyHasItemReferenceForCallIDs(reqBody, callIDs) +} + +func optimizedValidateFunctionCallOutputContext(reqBody map[string]any) bool { + validation := ValidateFunctionCallOutputContext(reqBody) + if !validation.HasFunctionCallOutput { + return true + } + previousResponseID, _ := reqBody["previous_response_id"].(string) + if strings.TrimSpace(previousResponseID) != "" { + return true + } + if validation.HasToolCallContext { + return true + } + if validation.HasFunctionCallOutputMissingCallID { + return false + } + return validation.HasItemReferenceForAllCallIDs +} + +func legacyHasFunctionCallOutput(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + input, ok := reqBody["input"].([]any) + if !ok { + return false + } + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType == "function_call_output" { + return true + } + } + return false +} + +func legacyHasToolCallContext(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + input, ok := reqBody["input"].([]any) + if !ok { + return false + } + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType != "tool_call" && itemType != "function_call" { + continue + } + if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" { + return true + } + } + return false +} + +func legacyFunctionCallOutputCallIDs(reqBody map[string]any) []string { + if reqBody == nil { + return nil + } + input, ok := reqBody["input"].([]any) + if !ok { + return nil 
+ } + ids := make(map[string]struct{}) + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType != "function_call_output" { + continue + } + if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" { + ids[callID] = struct{}{} + } + } + if len(ids) == 0 { + return nil + } + callIDs := make([]string, 0, len(ids)) + for id := range ids { + callIDs = append(callIDs, id) + } + return callIDs +} + +func legacyHasFunctionCallOutputMissingCallID(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + input, ok := reqBody["input"].([]any) + if !ok { + return false + } + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType != "function_call_output" { + continue + } + callID, _ := itemMap["call_id"].(string) + if strings.TrimSpace(callID) == "" { + return true + } + } + return false +} + +func legacyHasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool { + if reqBody == nil || len(callIDs) == 0 { + return false + } + input, ok := reqBody["input"].([]any) + if !ok { + return false + } + referenceIDs := make(map[string]struct{}) + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType != "item_reference" { + continue + } + idValue, _ := itemMap["id"].(string) + idValue = strings.TrimSpace(idValue) + if idValue == "" { + continue + } + referenceIDs[idValue] = struct{}{} + } + if len(referenceIDs) == 0 { + return false + } + for _, callID := range callIDs { + if _, ok := referenceIDs[callID]; !ok { + return false + } + } + return true +} + +func legacyParseWSIngressPayload(raw []byte) (eventType, model, promptCacheKey, previousResponseID string, payload map[string]any, err error) { + values := gjson.GetManyBytes(raw, "type", "model", 
"prompt_cache_key", "previous_response_id") + eventType = strings.TrimSpace(values[0].String()) + if eventType == "" { + eventType = "response.create" + } + model = strings.TrimSpace(values[1].String()) + promptCacheKey = strings.TrimSpace(values[2].String()) + previousResponseID = strings.TrimSpace(values[3].String()) + payload = make(map[string]any) + if err = json.Unmarshal(raw, &payload); err != nil { + return "", "", "", "", nil, err + } + if _, exists := payload["type"]; !exists { + payload["type"] = "response.create" + } + return eventType, model, promptCacheKey, previousResponseID, payload, nil +} + +func optimizedParseWSIngressPayload(raw []byte) (eventType, model, promptCacheKey, previousResponseID string, payload map[string]any, err error) { + payload = make(map[string]any) + if err = json.Unmarshal(raw, &payload); err != nil { + return "", "", "", "", nil, err + } + eventType = openAIWSPayloadString(payload, "type") + if eventType == "" { + eventType = "response.create" + payload["type"] = eventType + } + model = openAIWSPayloadString(payload, "model") + promptCacheKey = openAIWSPayloadString(payload, "prompt_cache_key") + previousResponseID = openAIWSPayloadString(payload, "previous_response_id") + return eventType, model, promptCacheKey, previousResponseID, payload, nil +} + +func legacyExtractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) { + var response struct { + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + InputTokenDetails struct { + CachedTokens int `json:"cached_tokens"` + } `json:"input_tokens_details"` + } `json:"usage"` + } + if err := json.Unmarshal(body, &response); err != nil { + return OpenAIUsage{}, false + } + return OpenAIUsage{ + InputTokens: response.Usage.InputTokens, + OutputTokens: response.Usage.OutputTokens, + CacheReadInputTokens: response.Usage.InputTokenDetails.CachedTokens, + }, true +} diff --git a/backend/internal/service/openai_oauth_passthrough_test.go 
b/backend/internal/service/openai_oauth_passthrough_test.go new file mode 100644 index 00000000..0840d3b1 --- /dev/null +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -0,0 +1,928 @@ +package service + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func f64p(v float64) *float64 { return &v } + +type httpUpstreamRecorder struct { + lastReq *http.Request + lastBody []byte + + resp *http.Response + err error +} + +func (u *httpUpstreamRecorder) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + u.lastReq = req + if req != nil && req.Body != nil { + b, _ := io.ReadAll(req.Body) + u.lastBody = b + _ = req.Body.Close() + req.Body = io.NopCloser(bytes.NewReader(b)) + } + if u.err != nil { + return nil, u.err + } + return u.resp, nil +} + +func (u *httpUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { + return u.Do(req, proxyURL, accountID, accountConcurrency) +} + +var structuredLogCaptureMu sync.Mutex + +type inMemoryLogSink struct { + mu sync.Mutex + events []*logger.LogEvent +} + +func (s *inMemoryLogSink) WriteLogEvent(event *logger.LogEvent) { + if event == nil { + return + } + cloned := *event + if event.Fields != nil { + cloned.Fields = make(map[string]any, len(event.Fields)) + for k, v := range event.Fields { + cloned.Fields[k] = v + } + } + s.mu.Lock() + s.events = append(s.events, &cloned) + s.mu.Unlock() +} + +func (s *inMemoryLogSink) ContainsMessage(substr string) bool { + s.mu.Lock() + defer s.mu.Unlock() + for _, ev := range s.events { + if ev != nil && strings.Contains(ev.Message, substr) { + 
return true + } + } + return false +} + +func (s *inMemoryLogSink) ContainsMessageAtLevel(substr, level string) bool { + s.mu.Lock() + defer s.mu.Unlock() + wantLevel := strings.ToLower(strings.TrimSpace(level)) + for _, ev := range s.events { + if ev == nil { + continue + } + if strings.Contains(ev.Message, substr) && strings.ToLower(strings.TrimSpace(ev.Level)) == wantLevel { + return true + } + } + return false +} + +func (s *inMemoryLogSink) ContainsFieldValue(field, substr string) bool { + s.mu.Lock() + defer s.mu.Unlock() + for _, ev := range s.events { + if ev == nil || ev.Fields == nil { + continue + } + if v, ok := ev.Fields[field]; ok && strings.Contains(fmt.Sprint(v), substr) { + return true + } + } + return false +} + +func (s *inMemoryLogSink) ContainsField(field string) bool { + s.mu.Lock() + defer s.mu.Unlock() + for _, ev := range s.events { + if ev == nil || ev.Fields == nil { + continue + } + if _, ok := ev.Fields[field]; ok { + return true + } + } + return false +} + +func captureStructuredLog(t *testing.T) (*inMemoryLogSink, func()) { + t.Helper() + structuredLogCaptureMu.Lock() + + err := logger.Init(logger.InitOptions{ + Level: "debug", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: true, + ToFile: false, + }, + Sampling: logger.SamplingOptions{Enabled: false}, + }) + require.NoError(t, err) + + sink := &inMemoryLogSink{} + logger.SetSink(sink) + return sink, func() { + logger.SetSink(nil) + structuredLogCaptureMu.Unlock() + } +} + +func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormalized(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + c.Request.Header.Set("Authorization", "Bearer inbound-should-not-forward") + 
c.Request.Header.Set("Cookie", "secret=1") + c.Request.Header.Set("X-Api-Key", "sk-inbound") + c.Request.Header.Set("X-Goog-Api-Key", "goog-inbound") + c.Request.Header.Set("Accept-Encoding", "gzip") + c.Request.Header.Set("Proxy-Authorization", "Basic abc") + c.Request.Header.Set("X-Test", "keep") + + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`) + + upstreamSSE := strings.Join([]string{ + `data: {"type":"response.output_item.added","item":{"type":"tool_call","tool_calls":[{"function":{"name":"apply_patch"}}]}}`, + "", + "data: [DONE]", + "", + }, "\n") + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader(upstreamSSE)), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + openAITokenProvider: &OpenAITokenProvider{ // minimal: will be bypassed by nil cache/service, but GetAccessToken uses provider only if non-nil + accountRepo: nil, + }, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + // Use the gateway method that reads token from credentials when provider is nil. 
+ svc.openAITokenProvider = nil + + result, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + + // 1) 透传 OAuth 请求体与旧链路关键行为保持一致:store=false + stream=true。 + require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool()) + require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool()) + require.Equal(t, "local-test-instructions", strings.TrimSpace(gjson.GetBytes(upstream.lastBody, "instructions").String())) + // 其余关键字段保持原值。 + require.Equal(t, "gpt-5.2", gjson.GetBytes(upstream.lastBody, "model").String()) + require.Equal(t, "hi", gjson.GetBytes(upstream.lastBody, "input.0.text").String()) + + // 2) only auth is replaced; inbound auth/cookie are not forwarded + require.Equal(t, "Bearer oauth-token", upstream.lastReq.Header.Get("Authorization")) + require.Equal(t, "codex_cli_rs/0.1.0", upstream.lastReq.Header.Get("User-Agent")) + require.Empty(t, upstream.lastReq.Header.Get("Cookie")) + require.Empty(t, upstream.lastReq.Header.Get("X-Api-Key")) + require.Empty(t, upstream.lastReq.Header.Get("X-Goog-Api-Key")) + require.Empty(t, upstream.lastReq.Header.Get("Accept-Encoding")) + require.Empty(t, upstream.lastReq.Header.Get("Proxy-Authorization")) + require.Empty(t, upstream.lastReq.Header.Get("X-Test")) + + // 3) required OAuth headers are present + require.Equal(t, "chatgpt.com", upstream.lastReq.Host) + require.Equal(t, "chatgpt-acc", upstream.lastReq.Header.Get("chatgpt-account-id")) + + // 4) downstream SSE keeps tool name (no toolCorrector) + body := rec.Body.String() + require.Contains(t, body, "apply_patch") + require.NotContains(t, body, "\"name\":\"edit\"") +} + +func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + 
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown") + c.Request.Header.Set("Content-Type", "application/json") + c.Request.Header.Set("OpenAI-Beta", "responses=experimental") + + // Codex 模型且缺少 instructions,应在本地直接 403 拒绝,不触达上游。 + originalBody := []byte(`{"model":"gpt-5.1-codex-max","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + result, err := svc.Forward(context.Background(), c, account, originalBody) + require.Error(t, err) + require.Nil(t, result) + require.Equal(t, http.StatusForbidden, rec.Code) + require.Contains(t, rec.Body.String(), "requires a non-empty instructions field") + require.Nil(t, upstream.lastReq) + + require.True(t, logSink.ContainsMessage("OpenAI passthrough 本地拦截:Codex 请求缺少有效 instructions")) + require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown")) + require.True(t, logSink.ContainsFieldValue("reject_reason", "instructions_missing")) +} + +func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *testing.T) { + 
gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + + // store=true + stream=false should be forced to store=false + stream=true by applyCodexOAuthTransform (OAuth legacy path) + inputBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader("data: [DONE]\n\n")), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": false}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, inputBody) + require.NoError(t, err) + + // legacy path rewrites request body (not byte-equal) + require.NotEqual(t, inputBody, upstream.lastBody) + require.Contains(t, string(upstream.lastBody), `"store":false`) + require.Contains(t, string(upstream.lastBody), `"stream":true`) +} + +func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + // 复合 UA(前缀不是 codex_cli_rs),历史实现会误判为非 Codex 并走 opencode。 + c.Request.Header.Set("User-Agent", "Mozilla/5.0 
codex_cli_rs/0.1.0") + + inputBody := []byte(`{"model":"gpt-5.2","stream":true,"store":false,"input":[{"type":"text","text":"hi"}]}`) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader("data: [DONE]\n\n")), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": false}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, inputBody) + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + require.Equal(t, "codex_cli_rs", upstream.lastReq.Header.Get("originator")) + require.NotEqual(t, "opencode", upstream.lastReq.Header.Get("originator")) +} + +func TestOpenAIGatewayService_OAuthPassthrough_ResponseHeadersAllowXCodex(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + + headers := make(http.Header) + headers.Set("Content-Type", "application/json") + headers.Set("x-request-id", "rid") + headers.Set("x-codex-primary-used-percent", "12") + headers.Set("x-codex-secondary-used-percent", "34") + headers.Set("x-codex-primary-window-minutes", "300") + headers.Set("x-codex-secondary-window-minutes", "10080") + 
headers.Set("x-codex-primary-reset-after-seconds", "1") + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: headers, + Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + + require.Equal(t, "12", rec.Header().Get("x-codex-primary-used-percent")) + require.Equal(t, "34", rec.Header().Get("x-codex-secondary-used-percent")) +} + +func TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughFlag(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + + resp := &http.Response{ + StatusCode: http.StatusBadRequest, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader(`{"error":{"message":"bad"}}`)), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, 
+ Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.Error(t, err) + + // should append an upstream error event with passthrough=true + v, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + arr, ok := v.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.NotEmpty(t, arr) + require.True(t, arr[len(arr)-1].Passthrough) +} + +func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + // Non-Codex UA + c.Request.Header.Set("User-Agent", "curl/8.0") + + inputBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader("data: [DONE]\n\n")), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, inputBody) + 
require.NoError(t, err) + require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool()) + require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool()) + require.Equal(t, "codex_cli_rs/0.104.0", upstream.lastReq.Header.Get("User-Agent")) +} + +func TestOpenAIGatewayService_CodexCLIOnly_RejectsNonCodexClient(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "curl/8.0") + + inputBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`) + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true, "codex_cli_only": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, inputBody) + require.Error(t, err) + require.Equal(t, http.StatusForbidden, rec.Code) + require.Contains(t, rec.Body.String(), "Codex official clients") +} + +func TestOpenAIGatewayService_CodexCLIOnly_AllowOfficialClientFamilies(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + ua string + originator string + }{ + {name: "codex_cli_rs", ua: "codex_cli_rs/0.99.0", originator: ""}, + {name: "codex_vscode", ua: "codex_vscode/1.0.0", originator: ""}, + {name: "codex_app", ua: "codex_app/2.1.0", originator: ""}, + {name: "originator_codex_chatgpt_desktop", ua: "curl/8.0", originator: "codex_chatgpt_desktop"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := 
httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", tt.ua) + if tt.originator != "" { + c.Request.Header.Set("originator", tt.originator) + } + + inputBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader("data: [DONE]\n\n")), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true, "codex_cli_only": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, inputBody) + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + }) + } +} + +func TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) + + upstreamSSE := strings.Join([]string{ + `data: {"type":"response.output_text.delta","delta":"h"}`, + "", + "data: [DONE]", + "", + }, "\n") + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: 
http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader(upstreamSSE)), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + start := time.Now() + result, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + // sanity: duration after start + require.GreaterOrEqual(t, time.Since(start), time.Duration(0)) + require.NotNil(t, result.FirstTokenMs) + require.GreaterOrEqual(t, *result.FirstTokenMs, 0) +} + +func TestOpenAIGatewayService_OAuthPassthrough_StreamClientDisconnectStillCollectsUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + // 首次写入成功,后续写入失败,模拟客户端中途断开。 + c.Writer = &failingGinWriter{ResponseWriter: c.Writer, failAfter: 1} + + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) + + upstreamSSE := strings.Join([]string{ + `data: {"type":"response.output_text.delta","delta":"h"}`, + "", + `data: {"type":"response.completed","response":{"usage":{"input_tokens":11,"output_tokens":7,"input_tokens_details":{"cached_tokens":3}}}}`, + "", + "data: [DONE]", + "", + }, "\n") + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, 
"x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader(upstreamSSE)), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + result, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.NotNil(t, result.FirstTokenMs) + require.Equal(t, 11, result.Usage.InputTokens) + require.Equal(t, 7, result.Usage.OutputTokens) + require.Equal(t, 3, result.Usage.CacheReadInputTokens) +} + +func TestOpenAIGatewayService_APIKeyPassthrough_PreservesBodyAndUsesResponsesEndpoint(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "curl/8.0") + c.Request.Header.Set("X-Test", "keep") + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"max_output_tokens":128,"input":[{"type":"text","text":"hi"}]}`) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + 
} + + account := &Account{ + ID: 456, + Name: "apikey-acc", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{"api_key": "sk-api-key", "base_url": "https://api.openai.com"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + require.Equal(t, originalBody, upstream.lastBody) + require.Equal(t, "https://api.openai.com/v1/responses", upstream.lastReq.URL.String()) + require.Equal(t, "Bearer sk-api-key", upstream.lastReq.Header.Get("Authorization")) + require.Equal(t, "curl/8.0", upstream.lastReq.Header.Get("User-Agent")) + require.Empty(t, upstream.lastReq.Header.Get("X-Test")) +} + +func TestOpenAIGatewayService_OAuthPassthrough_WarnOnTimeoutHeadersForStream(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + c.Request.Header.Set("x-stainless-timeout", "10000") + + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "X-Request-Id": []string{"rid-timeout"}}, + Body: io.NopCloser(strings.NewReader("data: [DONE]\n\n")), + } + upstream := &httpUpstreamRecorder{resp: resp} + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 321, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: 
map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.True(t, logSink.ContainsMessage("检测到超时相关请求头,将按配置过滤以降低断流风险")) + require.True(t, logSink.ContainsFieldValue("timeout_headers", "x-stainless-timeout=10000")) +} + +func TestOpenAIGatewayService_OAuthPassthrough_InfoWhenStreamEndsWithoutDone(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) + // 注意:刻意不发送 [DONE],模拟上游中途断流。 + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "X-Request-Id": []string{"rid-truncate"}}, + Body: io.NopCloser(strings.NewReader("data: {\"type\":\"response.output_text.delta\",\"delta\":\"h\"}\n\n")), + } + upstream := &httpUpstreamRecorder{resp: resp} + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 654, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.True(t, logSink.ContainsMessage("上游流在未收到 [DONE] 
时结束,疑似断流")) + require.True(t, logSink.ContainsMessageAtLevel("上游流在未收到 [DONE] 时结束,疑似断流", "info")) + require.True(t, logSink.ContainsFieldValue("upstream_request_id", "rid-truncate")) +} + +func TestOpenAIGatewayService_OAuthPassthrough_DefaultFiltersTimeoutHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + c.Request.Header.Set("x-stainless-timeout", "120000") + c.Request.Header.Set("X-Test", "keep") + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "X-Request-Id": []string{"rid-filter-default"}}, + Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)), + } + upstream := &httpUpstreamRecorder{resp: resp} + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 111, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + require.Empty(t, upstream.lastReq.Header.Get("x-stainless-timeout")) + require.Empty(t, upstream.lastReq.Header.Get("X-Test")) +} + +func TestOpenAIGatewayService_OAuthPassthrough_AllowTimeoutHeadersWhenConfigured(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := 
httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + c.Request.Header.Set("x-stainless-timeout", "120000") + c.Request.Header.Set("X-Test", "keep") + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "X-Request-Id": []string{"rid-filter-allow"}}, + Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)), + } + upstream := &httpUpstreamRecorder{resp: resp} + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ + ForceCodexCLI: false, + OpenAIPassthroughAllowTimeoutHeaders: true, + }}, + httpUpstream: upstream, + } + account := &Account{ + ID: 222, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + require.Equal(t, "120000", upstream.lastReq.Header.Get("x-stainless-timeout")) + require.Empty(t, upstream.lastReq.Header.Get("X-Test")) +} diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go index ca7470b9..72f4bbb0 100644 --- a/backend/internal/service/openai_oauth_service.go +++ b/backend/internal/service/openai_oauth_service.go @@ -2,13 +2,31 @@ package service import ( "context" + "crypto/subtle" + "encoding/json" + "io" + "log/slog" "net/http" + "regexp" + "sort" + 
"strconv" + "strings" "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" ) +var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session" + +var soraSessionCookiePattern = regexp.MustCompile(`(?i)(?:^|[\n\r;])\s*(?:(?:set-cookie|cookie)\s*:\s*)?__Secure-(?:next-auth|authjs)\.session-token(?:\.(\d+))?=([^;\r\n]+)`) + +type soraSessionChunk struct { + index int + value string +} + // OpenAIOAuthService handles OpenAI OAuth authentication flows type OpenAIOAuthService struct { sessionStore *openai.SessionStore @@ -32,7 +50,7 @@ type OpenAIAuthURLResult struct { } // GenerateAuthURL generates an OpenAI OAuth authorization URL -func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI string) (*OpenAIAuthURLResult, error) { +func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, platform string) (*OpenAIAuthURLResult, error) { // Generate PKCE values state, err := openai.GenerateState() if err != nil { @@ -68,11 +86,14 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 if redirectURI == "" { redirectURI = openai.DefaultRedirectURI } + normalizedPlatform := normalizeOpenAIOAuthPlatform(platform) + clientID, _ := openai.OAuthClientConfigByPlatform(normalizedPlatform) // Store session session := &openai.OAuthSession{ State: state, CodeVerifier: codeVerifier, + ClientID: clientID, RedirectURI: redirectURI, ProxyURL: proxyURL, CreatedAt: time.Now(), @@ -80,7 +101,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 s.sessionStore.Set(sessionID, session) // Build authorization URL - authURL := openai.BuildAuthorizationURL(state, codeChallenge, redirectURI) + authURL := openai.BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, normalizedPlatform) return &OpenAIAuthURLResult{ AuthURL: 
authURL, @@ -92,6 +113,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 type OpenAIExchangeCodeInput struct { SessionID string Code string + State string RedirectURI string ProxyID *int64 } @@ -103,6 +125,7 @@ type OpenAITokenInfo struct { IDToken string `json:"id_token,omitempty"` ExpiresIn int64 `json:"expires_in"` ExpiresAt int64 `json:"expires_at"` + ClientID string `json:"client_id,omitempty"` Email string `json:"email,omitempty"` ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"` ChatGPTUserID string `json:"chatgpt_user_id,omitempty"` @@ -116,6 +139,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch if !ok { return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired") } + if input.State == "" { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_STATE_REQUIRED", "oauth state is required") + } + if subtle.ConstantTimeCompare([]byte(input.State), []byte(session.State)) != 1 { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_STATE", "invalid oauth state") + } // Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL proxyURL := session.ProxyURL @@ -134,9 +163,13 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch if input.RedirectURI != "" { redirectURI = input.RedirectURI } + clientID := strings.TrimSpace(session.ClientID) + if clientID == "" { + clientID = openai.ClientID + } // Exchange code for token - tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL) + tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL, clientID) if err != nil { return nil, err } @@ -144,8 +177,10 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch // Parse ID token to get user info var userInfo *openai.UserInfo if 
tokenResp.IDToken != "" { - claims, err := openai.ParseIDToken(tokenResp.IDToken) - if err == nil { + claims, parseErr := openai.ParseIDToken(tokenResp.IDToken) + if parseErr != nil { + slog.Warn("openai_oauth_id_token_parse_failed", "error", parseErr) + } else { userInfo = claims.GetUserInfo() } } @@ -159,6 +194,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch IDToken: tokenResp.IDToken, ExpiresIn: int64(tokenResp.ExpiresIn), ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn), + ClientID: clientID, } if userInfo != nil { @@ -173,7 +209,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch // RefreshToken refreshes an OpenAI OAuth token func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) { - tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL) + return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "") +} + +// RefreshTokenWithClientID refreshes an OpenAI/Sora OAuth token with optional client_id. 
+func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) { + tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID) if err != nil { return nil, err } @@ -181,8 +222,10 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri // Parse ID token to get user info var userInfo *openai.UserInfo if tokenResp.IDToken != "" { - claims, err := openai.ParseIDToken(tokenResp.IDToken) - if err == nil { + claims, parseErr := openai.ParseIDToken(tokenResp.IDToken) + if parseErr != nil { + slog.Warn("openai_oauth_id_token_parse_failed", "error", parseErr) + } else { userInfo = claims.GetUserInfo() } } @@ -194,6 +237,9 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri ExpiresIn: int64(tokenResp.ExpiresIn), ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn), } + if trimmed := strings.TrimSpace(clientID); trimmed != "" { + tokenInfo.ClientID = trimmed + } if userInfo != nil { tokenInfo.Email = userInfo.Email @@ -205,13 +251,221 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri return tokenInfo, nil } -// RefreshAccountToken refreshes token for an OpenAI account -func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) { - if !account.IsOpenAI() { - return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account") +// ExchangeSoraSessionToken exchanges Sora session_token to access_token. 
+func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) { + sessionToken = normalizeSoraSessionTokenInput(sessionToken) + if strings.TrimSpace(sessionToken) == "" { + return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required") } - refreshToken := account.GetOpenAIRefreshToken() + proxyURL, err := s.resolveProxyURL(ctx, proxyID) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAISoraSessionAuthURL, nil) + if err != nil { + return nil, infraerrors.Newf(http.StatusInternalServerError, "SORA_SESSION_REQUEST_BUILD_FAILED", "failed to build request: %v", err) + } + req.Header.Set("Cookie", "__Secure-next-auth.session-token="+strings.TrimSpace(sessionToken)) + req.Header.Set("Accept", "application/json") + req.Header.Set("Origin", "https://sora.chatgpt.com") + req.Header.Set("Referer", "https://sora.chatgpt.com/") + req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + + client, err := httpclient.GetClient(httpclient.Options{ + ProxyURL: proxyURL, + Timeout: 120 * time.Second, + }) + if err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_CLIENT_FAILED", "create http client failed: %v", err) + } + resp, err := client.Do(req) + if err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if resp.StatusCode != http.StatusOK { + return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_EXCHANGE_FAILED", "status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + var sessionResp struct { + AccessToken string `json:"accessToken"` + Expires string `json:"expires"` + User struct { + Email string `json:"email"` + Name string 
`json:"name"` + } `json:"user"` + } + if err := json.Unmarshal(body, &sessionResp); err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_PARSE_FAILED", "failed to parse response: %v", err) + } + if strings.TrimSpace(sessionResp.AccessToken) == "" { + return nil, infraerrors.New(http.StatusBadGateway, "SORA_SESSION_ACCESS_TOKEN_MISSING", "session exchange response missing access token") + } + + expiresAt := time.Now().Add(time.Hour).Unix() + if strings.TrimSpace(sessionResp.Expires) != "" { + if parsed, parseErr := time.Parse(time.RFC3339, sessionResp.Expires); parseErr == nil { + expiresAt = parsed.Unix() + } + } + expiresIn := expiresAt - time.Now().Unix() + if expiresIn < 0 { + expiresIn = 0 + } + + return &OpenAITokenInfo{ + AccessToken: strings.TrimSpace(sessionResp.AccessToken), + ExpiresIn: expiresIn, + ExpiresAt: expiresAt, + ClientID: openai.SoraClientID, + Email: strings.TrimSpace(sessionResp.User.Email), + }, nil +} + +func normalizeSoraSessionTokenInput(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "" + } + + matches := soraSessionCookiePattern.FindAllStringSubmatch(trimmed, -1) + if len(matches) == 0 { + return sanitizeSessionToken(trimmed) + } + + chunkMatches := make([]soraSessionChunk, 0, len(matches)) + singleValues := make([]string, 0, len(matches)) + + for _, match := range matches { + if len(match) < 3 { + continue + } + + value := sanitizeSessionToken(match[2]) + if value == "" { + continue + } + + if strings.TrimSpace(match[1]) == "" { + singleValues = append(singleValues, value) + continue + } + + idx, err := strconv.Atoi(strings.TrimSpace(match[1])) + if err != nil || idx < 0 { + continue + } + chunkMatches = append(chunkMatches, soraSessionChunk{ + index: idx, + value: value, + }) + } + + if merged := mergeLatestSoraSessionChunks(chunkMatches); merged != "" { + return merged + } + + if len(singleValues) > 0 { + return singleValues[len(singleValues)-1] + } + + return "" +} + 
+func mergeSoraSessionChunkSegment(chunks []soraSessionChunk, requiredMaxIndex int, requireComplete bool) string { + if len(chunks) == 0 { + return "" + } + + byIndex := make(map[int]string, len(chunks)) + for _, chunk := range chunks { + byIndex[chunk.index] = chunk.value + } + + if _, ok := byIndex[0]; !ok { + return "" + } + if requireComplete { + for idx := 0; idx <= requiredMaxIndex; idx++ { + if _, ok := byIndex[idx]; !ok { + return "" + } + } + } + + orderedIndexes := make([]int, 0, len(byIndex)) + for idx := range byIndex { + orderedIndexes = append(orderedIndexes, idx) + } + sort.Ints(orderedIndexes) + + var builder strings.Builder + for _, idx := range orderedIndexes { + if _, err := builder.WriteString(byIndex[idx]); err != nil { + return "" + } + } + return sanitizeSessionToken(builder.String()) +} + +func mergeLatestSoraSessionChunks(chunks []soraSessionChunk) string { + if len(chunks) == 0 { + return "" + } + + requiredMaxIndex := 0 + for _, chunk := range chunks { + if chunk.index > requiredMaxIndex { + requiredMaxIndex = chunk.index + } + } + + groupStarts := make([]int, 0, len(chunks)) + for idx, chunk := range chunks { + if chunk.index == 0 { + groupStarts = append(groupStarts, idx) + } + } + + if len(groupStarts) == 0 { + return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false) + } + + for i := len(groupStarts) - 1; i >= 0; i-- { + start := groupStarts[i] + end := len(chunks) + if i+1 < len(groupStarts) { + end = groupStarts[i+1] + } + if merged := mergeSoraSessionChunkSegment(chunks[start:end], requiredMaxIndex, true); merged != "" { + return merged + } + } + + return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false) +} + +func sanitizeSessionToken(raw string) string { + token := strings.TrimSpace(raw) + token = strings.Trim(token, "\"'`") + token = strings.TrimSuffix(token, ";") + return strings.TrimSpace(token) +} + +// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account +func (s *OpenAIOAuthService) 
RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) { + if account.Platform != PlatformOpenAI && account.Platform != PlatformSora { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI/Sora account") + } + if account.Type != AccountTypeOAuth { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account") + } + + refreshToken := account.GetCredential("refresh_token") if refreshToken == "" { return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available") } @@ -224,7 +478,8 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A } } - return s.RefreshToken(ctx, refreshToken, proxyURL) + clientID := account.GetCredential("client_id") + return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID) } // BuildAccountCredentials builds credentials map from token info @@ -232,9 +487,12 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) expiresAt := time.Unix(tokenInfo.ExpiresAt, 0).Format(time.RFC3339) creds := map[string]any{ - "access_token": tokenInfo.AccessToken, - "refresh_token": tokenInfo.RefreshToken, - "expires_at": expiresAt, + "access_token": tokenInfo.AccessToken, + "expires_at": expiresAt, + } + // 仅在刷新响应返回了新的 refresh_token 时才更新,防止用空值覆盖已有令牌 + if strings.TrimSpace(tokenInfo.RefreshToken) != "" { + creds["refresh_token"] = tokenInfo.RefreshToken } if tokenInfo.IDToken != "" { @@ -252,6 +510,9 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) if tokenInfo.OrganizationID != "" { creds["organization_id"] = tokenInfo.OrganizationID } + if strings.TrimSpace(tokenInfo.ClientID) != "" { + creds["client_id"] = strings.TrimSpace(tokenInfo.ClientID) + } return creds } @@ -260,3 +521,26 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) 
func (s *OpenAIOAuthService) Stop() { s.sessionStore.Stop() } + +func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) { + if proxyID == nil { + return "", nil + } + proxy, err := s.proxyRepo.GetByID(ctx, *proxyID) + if err != nil { + return "", infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err) + } + if proxy == nil { + return "", nil + } + return proxy.URL(), nil +} + +func normalizeOpenAIOAuthPlatform(platform string) string { + switch strings.ToLower(strings.TrimSpace(platform)) { + case PlatformSora: + return openai.OAuthPlatformSora + default: + return openai.OAuthPlatformOpenAI + } +} diff --git a/backend/internal/service/openai_oauth_service_auth_url_test.go b/backend/internal/service/openai_oauth_service_auth_url_test.go new file mode 100644 index 00000000..5f26903d --- /dev/null +++ b/backend/internal/service/openai_oauth_service_auth_url_test.go @@ -0,0 +1,67 @@ +package service + +import ( + "context" + "errors" + "net/url" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/stretchr/testify/require" +) + +type openaiOAuthClientAuthURLStub struct{} + +func (s *openaiOAuthClientAuthURLStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientAuthURLStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientAuthURLStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func TestOpenAIOAuthService_GenerateAuthURL_OpenAIKeepsCodexFlow(t *testing.T) { + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{}) + defer svc.Stop() + + 
result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformOpenAI) + require.NoError(t, err) + require.NotEmpty(t, result.AuthURL) + require.NotEmpty(t, result.SessionID) + + parsed, err := url.Parse(result.AuthURL) + require.NoError(t, err) + q := parsed.Query() + require.Equal(t, openai.ClientID, q.Get("client_id")) + require.Equal(t, "true", q.Get("codex_cli_simplified_flow")) + + session, ok := svc.sessionStore.Get(result.SessionID) + require.True(t, ok) + require.Equal(t, openai.ClientID, session.ClientID) +} + +// TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient 验证 Sora 平台复用 Codex CLI 的 +// client_id(支持 localhost redirect_uri),但不启用 codex_cli_simplified_flow。 +func TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient(t *testing.T) { + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{}) + defer svc.Stop() + + result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformSora) + require.NoError(t, err) + require.NotEmpty(t, result.AuthURL) + require.NotEmpty(t, result.SessionID) + + parsed, err := url.Parse(result.AuthURL) + require.NoError(t, err) + q := parsed.Query() + require.Equal(t, openai.ClientID, q.Get("client_id")) + require.Empty(t, q.Get("codex_cli_simplified_flow")) + + session, ok := svc.sessionStore.Get(result.SessionID) + require.True(t, ok) + require.Equal(t, openai.ClientID, session.ClientID) +} diff --git a/backend/internal/service/openai_oauth_service_sora_session_test.go b/backend/internal/service/openai_oauth_service_sora_session_test.go new file mode 100644 index 00000000..08da8557 --- /dev/null +++ b/backend/internal/service/openai_oauth_service_sora_session_test.go @@ -0,0 +1,173 @@ +package service + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/stretchr/testify/require" +) + +type openaiOAuthClientNoopStub struct{} + +func (s *openaiOAuthClientNoopStub) 
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientNoopStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientNoopStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-token") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + info, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil) + require.NoError(t, err) + require.NotNil(t, info) + require.Equal(t, "at-token", info.AccessToken) + require.Equal(t, "demo@example.com", info.Email) + require.Greater(t, info.ExpiresAt, int64(0)) +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"expires":"2099-01-01T00:00:00Z"}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = 
server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + _, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "missing access token") +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_AcceptsSetCookieLine(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-cookie-value") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + raw := "__Secure-next-auth.session-token.0=st-cookie-value; Domain=.chatgpt.com; Path=/; HttpOnly; Secure; SameSite=Lax" + info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil) + require.NoError(t, err) + require.Equal(t, "at-token", info.AccessToken) +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_MergesChunkedSetCookieLines(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=chunk-0chunk-1") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { 
openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + raw := strings.Join([]string{ + "Set-Cookie: __Secure-next-auth.session-token.1=chunk-1; Path=/; HttpOnly", + "Set-Cookie: __Secure-next-auth.session-token.0=chunk-0; Path=/; HttpOnly", + }, "\n") + info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil) + require.NoError(t, err) + require.Equal(t, "at-token", info.AccessToken) +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_PrefersLatestDuplicateChunks(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=new-0new-1") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + raw := strings.Join([]string{ + "Set-Cookie: __Secure-next-auth.session-token.0=old-0; Path=/; HttpOnly", + "Set-Cookie: __Secure-next-auth.session-token.1=old-1; Path=/; HttpOnly", + "Set-Cookie: __Secure-next-auth.session-token.0=new-0; Path=/; HttpOnly", + "Set-Cookie: __Secure-next-auth.session-token.1=new-1; Path=/; HttpOnly", + }, "\n") + info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil) + require.NoError(t, err) + require.Equal(t, "at-token", info.AccessToken) +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_UsesLatestCompleteChunkGroup(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + 
require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=ok-0ok-1") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + raw := strings.Join([]string{ + "set-cookie", + "__Secure-next-auth.session-token.0=ok-0; Domain=.chatgpt.com; Path=/", + "set-cookie", + "__Secure-next-auth.session-token.1=ok-1; Domain=.chatgpt.com; Path=/", + "set-cookie", + "__Secure-next-auth.session-token.0=partial-0; Domain=.chatgpt.com; Path=/", + }, "\n") + info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil) + require.NoError(t, err) + require.Equal(t, "at-token", info.AccessToken) +} diff --git a/backend/internal/service/openai_oauth_service_state_test.go b/backend/internal/service/openai_oauth_service_state_test.go new file mode 100644 index 00000000..29252328 --- /dev/null +++ b/backend/internal/service/openai_oauth_service_state_test.go @@ -0,0 +1,106 @@ +package service + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/stretchr/testify/require" +) + +type openaiOAuthClientStateStub struct { + exchangeCalled int32 + lastClientID string +} + +func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { + atomic.AddInt32(&s.exchangeCalled, 1) + s.lastClientID = clientID + return &openai.TokenResponse{ + AccessToken: "at", + RefreshToken: "rt", + ExpiresIn: 3600, + }, nil +} + +func (s *openaiOAuthClientStateStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) 
(*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientStateStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + return s.RefreshToken(ctx, refreshToken, proxyURL) +} + +func TestOpenAIOAuthService_ExchangeCode_StateRequired(t *testing.T) { + client := &openaiOAuthClientStateStub{} + svc := NewOpenAIOAuthService(nil, client) + defer svc.Stop() + + svc.sessionStore.Set("sid", &openai.OAuthSession{ + State: "expected-state", + CodeVerifier: "verifier", + RedirectURI: openai.DefaultRedirectURI, + CreatedAt: time.Now(), + }) + + _, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{ + SessionID: "sid", + Code: "auth-code", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "oauth state is required") + require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled)) +} + +func TestOpenAIOAuthService_ExchangeCode_StateMismatch(t *testing.T) { + client := &openaiOAuthClientStateStub{} + svc := NewOpenAIOAuthService(nil, client) + defer svc.Stop() + + svc.sessionStore.Set("sid", &openai.OAuthSession{ + State: "expected-state", + CodeVerifier: "verifier", + RedirectURI: openai.DefaultRedirectURI, + CreatedAt: time.Now(), + }) + + _, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{ + SessionID: "sid", + Code: "auth-code", + State: "wrong-state", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid oauth state") + require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled)) +} + +func TestOpenAIOAuthService_ExchangeCode_StateMatch(t *testing.T) { + client := &openaiOAuthClientStateStub{} + svc := NewOpenAIOAuthService(nil, client) + defer svc.Stop() + + svc.sessionStore.Set("sid", &openai.OAuthSession{ + State: "expected-state", + CodeVerifier: "verifier", + RedirectURI: openai.DefaultRedirectURI, + CreatedAt: time.Now(), + }) + + info, err := 
svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{ + SessionID: "sid", + Code: "auth-code", + State: "expected-state", + }) + require.NoError(t, err) + require.NotNil(t, info) + require.Equal(t, "at", info.AccessToken) + require.Equal(t, openai.ClientID, info.ClientID) + require.Equal(t, openai.ClientID, client.lastClientID) + require.Equal(t, int32(1), atomic.LoadInt32(&client.exchangeCalled)) + + _, ok := svc.sessionStore.Get("sid") + require.False(t, ok) +} diff --git a/backend/internal/service/openai_previous_response_id.go b/backend/internal/service/openai_previous_response_id.go new file mode 100644 index 00000000..95865086 --- /dev/null +++ b/backend/internal/service/openai_previous_response_id.go @@ -0,0 +1,37 @@ +package service + +import ( + "regexp" + "strings" +) + +const ( + OpenAIPreviousResponseIDKindEmpty = "empty" + OpenAIPreviousResponseIDKindResponseID = "response_id" + OpenAIPreviousResponseIDKindMessageID = "message_id" + OpenAIPreviousResponseIDKindUnknown = "unknown" +) + +var ( + openAIResponseIDPattern = regexp.MustCompile(`^resp_[A-Za-z0-9_-]{1,256}$`) + openAIMessageIDPattern = regexp.MustCompile(`^(msg|message|item|chatcmpl)_[A-Za-z0-9_-]{1,256}$`) +) + +// ClassifyOpenAIPreviousResponseIDKind classifies previous_response_id to improve diagnostics. 
+func ClassifyOpenAIPreviousResponseIDKind(id string) string { + trimmed := strings.TrimSpace(id) + if trimmed == "" { + return OpenAIPreviousResponseIDKindEmpty + } + if openAIResponseIDPattern.MatchString(trimmed) { + return OpenAIPreviousResponseIDKindResponseID + } + if openAIMessageIDPattern.MatchString(strings.ToLower(trimmed)) { + return OpenAIPreviousResponseIDKindMessageID + } + return OpenAIPreviousResponseIDKindUnknown +} + +func IsOpenAIPreviousResponseIDLikelyMessageID(id string) bool { + return ClassifyOpenAIPreviousResponseIDKind(id) == OpenAIPreviousResponseIDKindMessageID +} diff --git a/backend/internal/service/openai_previous_response_id_test.go b/backend/internal/service/openai_previous_response_id_test.go new file mode 100644 index 00000000..7867b864 --- /dev/null +++ b/backend/internal/service/openai_previous_response_id_test.go @@ -0,0 +1,34 @@ +package service + +import "testing" + +func TestClassifyOpenAIPreviousResponseIDKind(t *testing.T) { + tests := []struct { + name string + id string + want string + }{ + {name: "empty", id: " ", want: OpenAIPreviousResponseIDKindEmpty}, + {name: "response_id", id: "resp_0906a621bc423a8d0169a108637ef88197b74b0e2f37ba358f", want: OpenAIPreviousResponseIDKindResponseID}, + {name: "message_id", id: "msg_123456", want: OpenAIPreviousResponseIDKindMessageID}, + {name: "item_id", id: "item_abcdef", want: OpenAIPreviousResponseIDKindMessageID}, + {name: "unknown", id: "foo_123456", want: OpenAIPreviousResponseIDKindUnknown}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := ClassifyOpenAIPreviousResponseIDKind(tc.id); got != tc.want { + t.Fatalf("ClassifyOpenAIPreviousResponseIDKind(%q)=%q want=%q", tc.id, got, tc.want) + } + }) + } +} + +func TestIsOpenAIPreviousResponseIDLikelyMessageID(t *testing.T) { + if !IsOpenAIPreviousResponseIDLikelyMessageID("msg_123") { + t.Fatal("expected msg_123 to be identified as message id") + } + if 
IsOpenAIPreviousResponseIDLikelyMessageID("resp_123") { + t.Fatal("expected resp_123 not to be identified as message id") + } +} diff --git a/backend/internal/service/openai_sticky_compat.go b/backend/internal/service/openai_sticky_compat.go new file mode 100644 index 00000000..e897debc --- /dev/null +++ b/backend/internal/service/openai_sticky_compat.go @@ -0,0 +1,214 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "strings" + "sync/atomic" + "time" + + "github.com/cespare/xxhash/v2" + "github.com/gin-gonic/gin" +) + +type openAILegacySessionHashContextKey struct{} + +var openAILegacySessionHashKey = openAILegacySessionHashContextKey{} + +var ( + openAIStickyLegacyReadFallbackTotal atomic.Int64 + openAIStickyLegacyReadFallbackHit atomic.Int64 + openAIStickyLegacyDualWriteTotal atomic.Int64 +) + +func openAIStickyCompatStats() (legacyReadFallbackTotal, legacyReadFallbackHit, legacyDualWriteTotal int64) { + return openAIStickyLegacyReadFallbackTotal.Load(), + openAIStickyLegacyReadFallbackHit.Load(), + openAIStickyLegacyDualWriteTotal.Load() +} + +func deriveOpenAISessionHashes(sessionID string) (currentHash string, legacyHash string) { + normalized := strings.TrimSpace(sessionID) + if normalized == "" { + return "", "" + } + + currentHash = fmt.Sprintf("%016x", xxhash.Sum64String(normalized)) + sum := sha256.Sum256([]byte(normalized)) + legacyHash = hex.EncodeToString(sum[:]) + return currentHash, legacyHash +} + +func withOpenAILegacySessionHash(ctx context.Context, legacyHash string) context.Context { + if ctx == nil { + return nil + } + trimmed := strings.TrimSpace(legacyHash) + if trimmed == "" { + return ctx + } + return context.WithValue(ctx, openAILegacySessionHashKey, trimmed) +} + +func openAILegacySessionHashFromContext(ctx context.Context) string { + if ctx == nil { + return "" + } + value, _ := ctx.Value(openAILegacySessionHashKey).(string) + return strings.TrimSpace(value) +} + +func 
attachOpenAILegacySessionHashToGin(c *gin.Context, legacyHash string) { + if c == nil || c.Request == nil { + return + } + c.Request = c.Request.WithContext(withOpenAILegacySessionHash(c.Request.Context(), legacyHash)) +} + +func (s *OpenAIGatewayService) openAISessionHashReadOldFallbackEnabled() bool { + if s == nil || s.cfg == nil { + return true + } + return s.cfg.Gateway.OpenAIWS.SessionHashReadOldFallback +} + +func (s *OpenAIGatewayService) openAISessionHashDualWriteOldEnabled() bool { + if s == nil || s.cfg == nil { + return true + } + return s.cfg.Gateway.OpenAIWS.SessionHashDualWriteOld +} + +func (s *OpenAIGatewayService) openAISessionCacheKey(sessionHash string) string { + normalized := strings.TrimSpace(sessionHash) + if normalized == "" { + return "" + } + return "openai:" + normalized +} + +func (s *OpenAIGatewayService) openAILegacySessionCacheKey(ctx context.Context, sessionHash string) string { + legacyHash := openAILegacySessionHashFromContext(ctx) + if legacyHash == "" { + return "" + } + legacyKey := "openai:" + legacyHash + if legacyKey == s.openAISessionCacheKey(sessionHash) { + return "" + } + return legacyKey +} + +func (s *OpenAIGatewayService) openAIStickyLegacyTTL(ttl time.Duration) time.Duration { + legacyTTL := ttl + if legacyTTL <= 0 { + legacyTTL = openaiStickySessionTTL + } + if legacyTTL > 10*time.Minute { + return 10 * time.Minute + } + return legacyTTL +} + +func (s *OpenAIGatewayService) getStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string) (int64, error) { + if s == nil || s.cache == nil { + return 0, nil + } + + primaryKey := s.openAISessionCacheKey(sessionHash) + if primaryKey == "" { + return 0, nil + } + + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), primaryKey) + if err == nil && accountID > 0 { + return accountID, nil + } + if !s.openAISessionHashReadOldFallbackEnabled() { + return accountID, err + } + + legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash) + 
if legacyKey == "" { + return accountID, err + } + + openAIStickyLegacyReadFallbackTotal.Add(1) + legacyAccountID, legacyErr := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), legacyKey) + if legacyErr == nil && legacyAccountID > 0 { + openAIStickyLegacyReadFallbackHit.Add(1) + return legacyAccountID, nil + } + return accountID, err +} + +func (s *OpenAIGatewayService) setStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string, accountID int64, ttl time.Duration) error { + if s == nil || s.cache == nil || accountID <= 0 { + return nil + } + primaryKey := s.openAISessionCacheKey(sessionHash) + if primaryKey == "" { + return nil + } + + if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), primaryKey, accountID, ttl); err != nil { + return err + } + + if !s.openAISessionHashDualWriteOldEnabled() { + return nil + } + legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash) + if legacyKey == "" { + return nil + } + if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), legacyKey, accountID, s.openAIStickyLegacyTTL(ttl)); err != nil { + return err + } + openAIStickyLegacyDualWriteTotal.Add(1) + return nil +} + +func (s *OpenAIGatewayService) refreshStickySessionTTL(ctx context.Context, groupID *int64, sessionHash string, ttl time.Duration) error { + if s == nil || s.cache == nil { + return nil + } + primaryKey := s.openAISessionCacheKey(sessionHash) + if primaryKey == "" { + return nil + } + + err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), primaryKey, ttl) + if !s.openAISessionHashReadOldFallbackEnabled() && !s.openAISessionHashDualWriteOldEnabled() { + return err + } + + legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash) + if legacyKey != "" { + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), legacyKey, s.openAIStickyLegacyTTL(ttl)) + } + return err +} + +func (s *OpenAIGatewayService) deleteStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string) error { 
+ if s == nil || s.cache == nil { + return nil + } + primaryKey := s.openAISessionCacheKey(sessionHash) + if primaryKey == "" { + return nil + } + + err := s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), primaryKey) + if !s.openAISessionHashReadOldFallbackEnabled() && !s.openAISessionHashDualWriteOldEnabled() { + return err + } + + legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash) + if legacyKey != "" { + _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), legacyKey) + } + return err +} diff --git a/backend/internal/service/openai_sticky_compat_test.go b/backend/internal/service/openai_sticky_compat_test.go new file mode 100644 index 00000000..9f57c358 --- /dev/null +++ b/backend/internal/service/openai_sticky_compat_test.go @@ -0,0 +1,96 @@ +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +func TestGetStickySessionAccountID_FallbackToLegacyKey(t *testing.T) { + beforeFallbackTotal, beforeFallbackHit, _ := openAIStickyCompatStats() + + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + "openai:legacy-hash": 42, + }, + } + svc := &OpenAIGatewayService{ + cache: cache, + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + OpenAIWS: config.GatewayOpenAIWSConfig{ + SessionHashReadOldFallback: true, + }, + }, + }, + } + + ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash") + accountID, err := svc.getStickySessionAccountID(ctx, nil, "new-hash") + require.NoError(t, err) + require.Equal(t, int64(42), accountID) + + afterFallbackTotal, afterFallbackHit, _ := openAIStickyCompatStats() + require.Equal(t, beforeFallbackTotal+1, afterFallbackTotal) + require.Equal(t, beforeFallbackHit+1, afterFallbackHit) +} + +func TestSetStickySessionAccountID_DualWriteOldEnabled(t *testing.T) { + _, _, beforeDualWriteTotal := openAIStickyCompatStats() + + cache := 
&stubGatewayCache{sessionBindings: map[string]int64{}} + svc := &OpenAIGatewayService{ + cache: cache, + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + OpenAIWS: config.GatewayOpenAIWSConfig{ + SessionHashDualWriteOld: true, + }, + }, + }, + } + + ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash") + err := svc.setStickySessionAccountID(ctx, nil, "new-hash", 9, openaiStickySessionTTL) + require.NoError(t, err) + require.Equal(t, int64(9), cache.sessionBindings["openai:new-hash"]) + require.Equal(t, int64(9), cache.sessionBindings["openai:legacy-hash"]) + + _, _, afterDualWriteTotal := openAIStickyCompatStats() + require.Equal(t, beforeDualWriteTotal+1, afterDualWriteTotal) +} + +func TestSetStickySessionAccountID_DualWriteOldDisabled(t *testing.T) { + cache := &stubGatewayCache{sessionBindings: map[string]int64{}} + svc := &OpenAIGatewayService{ + cache: cache, + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + OpenAIWS: config.GatewayOpenAIWSConfig{ + SessionHashDualWriteOld: false, + }, + }, + }, + } + + ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash") + err := svc.setStickySessionAccountID(ctx, nil, "new-hash", 9, openaiStickySessionTTL) + require.NoError(t, err) + require.Equal(t, int64(9), cache.sessionBindings["openai:new-hash"]) + _, exists := cache.sessionBindings["openai:legacy-hash"] + require.False(t, exists) +} + +func TestSnapshotOpenAICompatibilityFallbackMetrics(t *testing.T) { + before := SnapshotOpenAICompatibilityFallbackMetrics() + + ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, true) + _, _ = ThinkingEnabledFromContext(ctx) + + after := SnapshotOpenAICompatibilityFallbackMetrics() + require.GreaterOrEqual(t, after.MetadataLegacyFallbackTotal, before.MetadataLegacyFallbackTotal+1) + require.GreaterOrEqual(t, after.MetadataLegacyFallbackThinkingEnabledTotal, before.MetadataLegacyFallbackThinkingEnabledTotal+1) +} diff --git 
a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go index 87a7713b..a8a6b96c 100644 --- a/backend/internal/service/openai_token_provider.go +++ b/backend/internal/service/openai_token_provider.go @@ -4,16 +4,74 @@ import ( "context" "errors" "log/slog" + "math/rand/v2" "strings" + "sync/atomic" "time" ) const ( - openAITokenRefreshSkew = 3 * time.Minute - openAITokenCacheSkew = 5 * time.Minute - openAILockWaitTime = 200 * time.Millisecond + openAITokenRefreshSkew = 3 * time.Minute + openAITokenCacheSkew = 5 * time.Minute + openAILockInitialWait = 20 * time.Millisecond + openAILockMaxWait = 120 * time.Millisecond + openAILockMaxAttempts = 5 + openAILockJitterRatio = 0.2 + openAILockWarnThresholdMs = 250 ) +// OpenAITokenRuntimeMetrics 表示 OpenAI token 刷新与锁竞争保护指标快照。 +type OpenAITokenRuntimeMetrics struct { + RefreshRequests int64 + RefreshSuccess int64 + RefreshFailure int64 + LockAcquireFailure int64 + LockContention int64 + LockWaitSamples int64 + LockWaitTotalMs int64 + LockWaitHit int64 + LockWaitMiss int64 + LastObservedUnixMs int64 +} + +type openAITokenRuntimeMetricsStore struct { + refreshRequests atomic.Int64 + refreshSuccess atomic.Int64 + refreshFailure atomic.Int64 + lockAcquireFailure atomic.Int64 + lockContention atomic.Int64 + lockWaitSamples atomic.Int64 + lockWaitTotalMs atomic.Int64 + lockWaitHit atomic.Int64 + lockWaitMiss atomic.Int64 + lastObservedUnixMs atomic.Int64 +} + +func (m *openAITokenRuntimeMetricsStore) snapshot() OpenAITokenRuntimeMetrics { + if m == nil { + return OpenAITokenRuntimeMetrics{} + } + return OpenAITokenRuntimeMetrics{ + RefreshRequests: m.refreshRequests.Load(), + RefreshSuccess: m.refreshSuccess.Load(), + RefreshFailure: m.refreshFailure.Load(), + LockAcquireFailure: m.lockAcquireFailure.Load(), + LockContention: m.lockContention.Load(), + LockWaitSamples: m.lockWaitSamples.Load(), + LockWaitTotalMs: m.lockWaitTotalMs.Load(), + LockWaitHit: m.lockWaitHit.Load(), + 
LockWaitMiss: m.lockWaitMiss.Load(), + LastObservedUnixMs: m.lastObservedUnixMs.Load(), + } +} + +func (m *openAITokenRuntimeMetricsStore) touchNow() { + if m == nil { + return + } + m.lastObservedUnixMs.Store(time.Now().UnixMilli()) +} + // OpenAITokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义) type OpenAITokenCache = GeminiTokenCache @@ -22,6 +80,7 @@ type OpenAITokenProvider struct { accountRepo AccountRepository tokenCache OpenAITokenCache openAIOAuthService *OpenAIOAuthService + metrics *openAITokenRuntimeMetricsStore } func NewOpenAITokenProvider( @@ -33,16 +92,32 @@ func NewOpenAITokenProvider( accountRepo: accountRepo, tokenCache: tokenCache, openAIOAuthService: openAIOAuthService, + metrics: &openAITokenRuntimeMetricsStore{}, + } +} + +func (p *OpenAITokenProvider) SnapshotRuntimeMetrics() OpenAITokenRuntimeMetrics { + if p == nil { + return OpenAITokenRuntimeMetrics{} + } + p.ensureMetrics() + return p.metrics.snapshot() +} + +func (p *OpenAITokenProvider) ensureMetrics() { + if p != nil && p.metrics == nil { + p.metrics = &openAITokenRuntimeMetricsStore{} } } // GetAccessToken 获取有效的 access_token func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) { + p.ensureMetrics() if account == nil { return "", errors.New("account is nil") } - if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth { - return "", errors.New("not an openai oauth account") + if (account.Platform != PlatformOpenAI && account.Platform != PlatformSora) || account.Type != AccountTypeOAuth { + return "", errors.New("not an openai/sora oauth account") } cacheKey := OpenAITokenCacheKey(account) @@ -64,6 +139,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew refreshFailed := false if needsRefresh && p.tokenCache != nil { + p.metrics.refreshRequests.Add(1) + p.metrics.touchNow() locked, lockErr := 
p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) if lockErr == nil && locked { defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() @@ -80,16 +157,23 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou } expiresAt = account.GetCredentialAsTime("expires_at") if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { - if p.openAIOAuthService == nil { + if account.Platform == PlatformSora { + slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID) + // Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。 + refreshFailed = true + } else if p.openAIOAuthService == nil { slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) + p.metrics.refreshFailure.Add(1) refreshFailed = true // 无法刷新,标记失败 } else { tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account) if err != nil { // 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err) + p.metrics.refreshFailure.Add(1) refreshFailed = true // 刷新失败,标记以使用短 TTL } else { + p.metrics.refreshSuccess.Add(1) newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo) for k, v := range account.Credentials { if _, exists := newCredentials[k]; !exists { @@ -106,6 +190,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou } } else if lockErr != nil { // Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时) + p.metrics.lockAcquireFailure.Add(1) + p.metrics.touchNow() slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr) // 检查 ctx 是否已取消 @@ -124,15 +210,22 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou // 仅在 expires_at 已过期/接近过期时才执行无锁刷新 if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { - if p.openAIOAuthService == nil { + if account.Platform == PlatformSora { + 
slog.Debug("openai_token_refresh_skipped_for_sora_degraded", "account_id", account.ID) + // Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。 + refreshFailed = true + } else if p.openAIOAuthService == nil { slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) + p.metrics.refreshFailure.Add(1) refreshFailed = true } else { tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account) if err != nil { slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err) + p.metrics.refreshFailure.Add(1) refreshFailed = true } else { + p.metrics.refreshSuccess.Add(1) newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo) for k, v := range account.Credentials { if _, exists := newCredentials[k]; !exists { @@ -148,16 +241,21 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou } } } else { - // 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存 - time.Sleep(openAILockWaitTime) - if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + // 锁被其他 worker 持有:使用短轮询+jitter,降低固定等待导致的尾延迟台阶。 + p.metrics.lockContention.Add(1) + p.metrics.touchNow() + token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey) + if waitErr != nil { + return "", waitErr + } + if strings.TrimSpace(token) != "" { slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID) return token, nil } } } - accessToken := account.GetOpenAIAccessToken() + accessToken := account.GetCredential("access_token") if strings.TrimSpace(accessToken) == "" { return "", errors.New("access_token not found in credentials") } @@ -198,3 +296,64 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou return accessToken, nil } + +func (p *OpenAITokenProvider) waitForTokenAfterLockRace(ctx context.Context, cacheKey string) (string, error) { + wait := openAILockInitialWait + totalWaitMs := int64(0) + for i := 0; i < openAILockMaxAttempts; i++ { + 
actualWait := jitterLockWait(wait) + timer := time.NewTimer(actualWait) + select { + case <-ctx.Done(): + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + return "", ctx.Err() + case <-timer.C: + } + + waitMs := actualWait.Milliseconds() + if waitMs < 0 { + waitMs = 0 + } + totalWaitMs += waitMs + p.metrics.lockWaitSamples.Add(1) + p.metrics.lockWaitTotalMs.Add(waitMs) + p.metrics.touchNow() + + token, err := p.tokenCache.GetAccessToken(ctx, cacheKey) + if err == nil && strings.TrimSpace(token) != "" { + p.metrics.lockWaitHit.Add(1) + if totalWaitMs >= openAILockWarnThresholdMs { + slog.Warn("openai_token_lock_wait_high", "wait_ms", totalWaitMs, "attempts", i+1) + } + return token, nil + } + + if wait < openAILockMaxWait { + wait *= 2 + if wait > openAILockMaxWait { + wait = openAILockMaxWait + } + } + } + + p.metrics.lockWaitMiss.Add(1) + if totalWaitMs >= openAILockWarnThresholdMs { + slog.Warn("openai_token_lock_wait_high", "wait_ms", totalWaitMs, "attempts", openAILockMaxAttempts) + } + return "", nil +} + +func jitterLockWait(base time.Duration) time.Duration { + if base <= 0 { + return 0 + } + minFactor := 1 - openAILockJitterRatio + maxFactor := 1 + openAILockJitterRatio + factor := minFactor + rand.Float64()*(maxFactor-minFactor) + return time.Duration(float64(base) * factor) +} diff --git a/backend/internal/service/openai_token_provider_test.go b/backend/internal/service/openai_token_provider_test.go index c2e3dbb0..1cd92367 100644 --- a/backend/internal/service/openai_token_provider_test.go +++ b/backend/internal/service/openai_token_provider_test.go @@ -375,7 +375,7 @@ func TestOpenAITokenProvider_WrongPlatform(t *testing.T) { token, err := provider.GetAccessToken(context.Background(), account) require.Error(t, err) - require.Contains(t, err.Error(), "not an openai oauth account") + require.Contains(t, err.Error(), "not an openai/sora oauth account") require.Empty(t, token) } @@ -389,7 +389,7 @@ func 
TestOpenAITokenProvider_WrongAccountType(t *testing.T) { token, err := provider.GetAccessToken(context.Background(), account) require.Error(t, err) - require.Contains(t, err.Error(), "not an openai oauth account") + require.Contains(t, err.Error(), "not an openai/sora oauth account") require.Empty(t, token) } @@ -808,3 +808,119 @@ func TestOpenAITokenProvider_Real_NilCredentials(t *testing.T) { require.Contains(t, err.Error(), "access_token not found") require.Empty(t, token) } + +func TestOpenAITokenProvider_Real_LockRace_PollingHitsCache(t *testing.T) { + cache := newOpenAITokenCacheStub() + cache.lockAcquired = false // 模拟锁被其他 worker 持有 + + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 207, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "fallback-token", + "expires_at": expiresAt, + }, + } + + cacheKey := OpenAITokenCacheKey(account) + go func() { + time.Sleep(5 * time.Millisecond) + cache.mu.Lock() + cache.tokens[cacheKey] = "winner-token" + cache.mu.Unlock() + }() + + provider := NewOpenAITokenProvider(nil, cache, nil) + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "winner-token", token) +} + +func TestOpenAITokenProvider_Real_LockRace_ContextCanceled(t *testing.T) { + cache := newOpenAITokenCacheStub() + cache.lockAcquired = false // 模拟锁被其他 worker 持有 + + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 208, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "fallback-token", + "expires_at": expiresAt, + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + provider := NewOpenAITokenProvider(nil, cache, nil) + start := time.Now() + token, err := provider.GetAccessToken(ctx, account) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + 
require.Empty(t, token) + require.Less(t, time.Since(start), 50*time.Millisecond) +} + +func TestOpenAITokenProvider_RuntimeMetrics_LockWaitHitAndSnapshot(t *testing.T) { + cache := newOpenAITokenCacheStub() + cache.lockAcquired = false + + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 209, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "fallback-token", + "expires_at": expiresAt, + }, + } + cacheKey := OpenAITokenCacheKey(account) + go func() { + time.Sleep(10 * time.Millisecond) + cache.mu.Lock() + cache.tokens[cacheKey] = "winner-token" + cache.mu.Unlock() + }() + + provider := NewOpenAITokenProvider(nil, cache, nil) + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "winner-token", token) + + metrics := provider.SnapshotRuntimeMetrics() + require.GreaterOrEqual(t, metrics.RefreshRequests, int64(1)) + require.GreaterOrEqual(t, metrics.LockContention, int64(1)) + require.GreaterOrEqual(t, metrics.LockWaitSamples, int64(1)) + require.GreaterOrEqual(t, metrics.LockWaitHit, int64(1)) + require.GreaterOrEqual(t, metrics.LockWaitTotalMs, int64(0)) + require.GreaterOrEqual(t, metrics.LastObservedUnixMs, int64(1)) +} + +func TestOpenAITokenProvider_RuntimeMetrics_LockAcquireFailure(t *testing.T) { + cache := newOpenAITokenCacheStub() + cache.lockErr = errors.New("redis lock error") + + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 210, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "fallback-token", + "expires_at": expiresAt, + }, + } + + provider := NewOpenAITokenProvider(nil, cache, nil) + _, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + + metrics := provider.SnapshotRuntimeMetrics() + require.GreaterOrEqual(t, metrics.LockAcquireFailure, int64(1)) + 
require.GreaterOrEqual(t, metrics.RefreshRequests, int64(1)) +} diff --git a/backend/internal/service/openai_tool_continuation.go b/backend/internal/service/openai_tool_continuation.go index e59082b2..dea3c172 100644 --- a/backend/internal/service/openai_tool_continuation.go +++ b/backend/internal/service/openai_tool_continuation.go @@ -2,6 +2,24 @@ package service import "strings" +// ToolContinuationSignals 聚合工具续链相关信号,避免重复遍历 input。 +type ToolContinuationSignals struct { + HasFunctionCallOutput bool + HasFunctionCallOutputMissingCallID bool + HasToolCallContext bool + HasItemReference bool + HasItemReferenceForAllCallIDs bool + FunctionCallOutputCallIDs []string +} + +// FunctionCallOutputValidation 汇总 function_call_output 关联性校验结果。 +type FunctionCallOutputValidation struct { + HasFunctionCallOutput bool + HasToolCallContext bool + HasFunctionCallOutputMissingCallID bool + HasItemReferenceForAllCallIDs bool +} + // NeedsToolContinuation 判定请求是否需要工具调用续链处理。 // 满足以下任一信号即视为续链:previous_response_id、input 内包含 function_call_output/item_reference、 // 或显式声明 tools/tool_choice。 @@ -18,107 +36,191 @@ func NeedsToolContinuation(reqBody map[string]any) bool { if hasToolChoiceSignal(reqBody) { return true } - if inputHasType(reqBody, "function_call_output") { - return true + input, ok := reqBody["input"].([]any) + if !ok { + return false } - if inputHasType(reqBody, "item_reference") { - return true + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType == "function_call_output" || itemType == "item_reference" { + return true + } } return false } +// AnalyzeToolContinuationSignals 单次遍历 input,提取 function_call_output/tool_call/item_reference 相关信号。 +func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSignals { + signals := ToolContinuationSignals{} + if reqBody == nil { + return signals + } + input, ok := reqBody["input"].([]any) + if !ok { + return signals + } + + 
var callIDs map[string]struct{} + var referenceIDs map[string]struct{} + + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + switch itemType { + case "tool_call", "function_call": + callID, _ := itemMap["call_id"].(string) + if strings.TrimSpace(callID) != "" { + signals.HasToolCallContext = true + } + case "function_call_output": + signals.HasFunctionCallOutput = true + callID, _ := itemMap["call_id"].(string) + callID = strings.TrimSpace(callID) + if callID == "" { + signals.HasFunctionCallOutputMissingCallID = true + continue + } + if callIDs == nil { + callIDs = make(map[string]struct{}) + } + callIDs[callID] = struct{}{} + case "item_reference": + signals.HasItemReference = true + idValue, _ := itemMap["id"].(string) + idValue = strings.TrimSpace(idValue) + if idValue == "" { + continue + } + if referenceIDs == nil { + referenceIDs = make(map[string]struct{}) + } + referenceIDs[idValue] = struct{}{} + } + } + + if len(callIDs) == 0 { + return signals + } + signals.FunctionCallOutputCallIDs = make([]string, 0, len(callIDs)) + allReferenced := len(referenceIDs) > 0 + for callID := range callIDs { + signals.FunctionCallOutputCallIDs = append(signals.FunctionCallOutputCallIDs, callID) + if allReferenced { + if _, ok := referenceIDs[callID]; !ok { + allReferenced = false + } + } + } + signals.HasItemReferenceForAllCallIDs = allReferenced + return signals +} + +// ValidateFunctionCallOutputContext 为 handler 提供低开销校验结果: +// 1) 无 function_call_output 直接返回 +// 2) 若已存在 tool_call/function_call 上下文则提前返回 +// 3) 仅在无工具上下文时才构建 call_id / item_reference 集合 +func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutputValidation { + result := FunctionCallOutputValidation{} + if reqBody == nil { + return result + } + input, ok := reqBody["input"].([]any) + if !ok { + return result + } + + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + 
continue + } + itemType, _ := itemMap["type"].(string) + switch itemType { + case "function_call_output": + result.HasFunctionCallOutput = true + case "tool_call", "function_call": + callID, _ := itemMap["call_id"].(string) + if strings.TrimSpace(callID) != "" { + result.HasToolCallContext = true + } + } + if result.HasFunctionCallOutput && result.HasToolCallContext { + return result + } + } + + if !result.HasFunctionCallOutput || result.HasToolCallContext { + return result + } + + callIDs := make(map[string]struct{}) + referenceIDs := make(map[string]struct{}) + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + switch itemType { + case "function_call_output": + callID, _ := itemMap["call_id"].(string) + callID = strings.TrimSpace(callID) + if callID == "" { + result.HasFunctionCallOutputMissingCallID = true + continue + } + callIDs[callID] = struct{}{} + case "item_reference": + idValue, _ := itemMap["id"].(string) + idValue = strings.TrimSpace(idValue) + if idValue == "" { + continue + } + referenceIDs[idValue] = struct{}{} + } + } + + if len(callIDs) == 0 || len(referenceIDs) == 0 { + return result + } + allReferenced := true + for callID := range callIDs { + if _, ok := referenceIDs[callID]; !ok { + allReferenced = false + break + } + } + result.HasItemReferenceForAllCallIDs = allReferenced + return result +} + // HasFunctionCallOutput 判断 input 是否包含 function_call_output,用于触发续链校验。 func HasFunctionCallOutput(reqBody map[string]any) bool { - if reqBody == nil { - return false - } - return inputHasType(reqBody, "function_call_output") + return AnalyzeToolContinuationSignals(reqBody).HasFunctionCallOutput } // HasToolCallContext 判断 input 是否包含带 call_id 的 tool_call/function_call, // 用于判断 function_call_output 是否具备可关联的上下文。 func HasToolCallContext(reqBody map[string]any) bool { - if reqBody == nil { - return false - } - input, ok := reqBody["input"].([]any) - if !ok { - return false - 
} - for _, item := range input { - itemMap, ok := item.(map[string]any) - if !ok { - continue - } - itemType, _ := itemMap["type"].(string) - if itemType != "tool_call" && itemType != "function_call" { - continue - } - if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" { - return true - } - } - return false + return AnalyzeToolContinuationSignals(reqBody).HasToolCallContext } // FunctionCallOutputCallIDs 提取 input 中 function_call_output 的 call_id 集合。 // 仅返回非空 call_id,用于与 item_reference.id 做匹配校验。 func FunctionCallOutputCallIDs(reqBody map[string]any) []string { - if reqBody == nil { - return nil - } - input, ok := reqBody["input"].([]any) - if !ok { - return nil - } - ids := make(map[string]struct{}) - for _, item := range input { - itemMap, ok := item.(map[string]any) - if !ok { - continue - } - itemType, _ := itemMap["type"].(string) - if itemType != "function_call_output" { - continue - } - if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" { - ids[callID] = struct{}{} - } - } - if len(ids) == 0 { - return nil - } - result := make([]string, 0, len(ids)) - for id := range ids { - result = append(result, id) - } - return result + return AnalyzeToolContinuationSignals(reqBody).FunctionCallOutputCallIDs } // HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的 function_call_output。 func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool { - if reqBody == nil { - return false - } - input, ok := reqBody["input"].([]any) - if !ok { - return false - } - for _, item := range input { - itemMap, ok := item.(map[string]any) - if !ok { - continue - } - itemType, _ := itemMap["type"].(string) - if itemType != "function_call_output" { - continue - } - callID, _ := itemMap["call_id"].(string) - if strings.TrimSpace(callID) == "" { - return true - } - } - return false + return AnalyzeToolContinuationSignals(reqBody).HasFunctionCallOutputMissingCallID } // HasItemReferenceForCallIDs 判断 item_reference.id 
是否覆盖所有 call_id。 @@ -152,32 +254,13 @@ func HasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool { return false } for _, callID := range callIDs { - if _, ok := referenceIDs[callID]; !ok { + if _, ok := referenceIDs[strings.TrimSpace(callID)]; !ok { return false } } return true } -// inputHasType 判断 input 中是否存在指定类型的 item。 -func inputHasType(reqBody map[string]any, want string) bool { - input, ok := reqBody["input"].([]any) - if !ok { - return false - } - for _, item := range input { - itemMap, ok := item.(map[string]any) - if !ok { - continue - } - itemType, _ := itemMap["type"].(string) - if itemType == want { - return true - } - } - return false -} - // hasNonEmptyString 判断字段是否为非空字符串。 func hasNonEmptyString(value any) bool { stringValue, ok := value.(string) diff --git a/backend/internal/service/openai_tool_corrector.go b/backend/internal/service/openai_tool_corrector.go index f4719275..348723a6 100644 --- a/backend/internal/service/openai_tool_corrector.go +++ b/backend/internal/service/openai_tool_corrector.go @@ -1,10 +1,15 @@ package service import ( - "encoding/json" + "bytes" "fmt" - "log" + "strconv" + "strings" "sync" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) // codexToolNameMapping 定义 Codex 原生工具名称到 OpenCode 工具名称的映射 @@ -61,237 +66,273 @@ func (c *CodexToolCorrector) CorrectToolCallsInSSEData(data string) (string, boo if data == "" || data == "\n" { return data, false } + correctedBytes, corrected := c.CorrectToolCallsInSSEBytes([]byte(data)) + if !corrected { + return data, false + } + return string(correctedBytes), true +} - // 尝试解析 JSON - var payload map[string]any - if err := json.Unmarshal([]byte(data), &payload); err != nil { - // 不是有效的 JSON,直接返回原数据 +// CorrectToolCallsInSSEBytes 修正 SSE JSON 数据中的工具调用(字节路径)。 +// 返回修正后的数据和是否进行了修正。 +func (c *CodexToolCorrector) CorrectToolCallsInSSEBytes(data []byte) ([]byte, bool) { + if len(bytes.TrimSpace(data)) == 0 { + 
return data, false + } + if !mayContainToolCallPayload(data) { + return data, false + } + if !gjson.ValidBytes(data) { + // 不是有效 JSON,直接返回原数据 return data, false } + updated := data corrected := false - - // 处理 tool_calls 数组 - if toolCalls, ok := payload["tool_calls"].([]any); ok { - if c.correctToolCallsArray(toolCalls) { + collect := func(changed bool, next []byte) { + if changed { corrected = true + updated = next } } - // 处理 function_call 对象 - if functionCall, ok := payload["function_call"].(map[string]any); ok { - if c.correctFunctionCall(functionCall) { - corrected = true - } + if next, changed := c.correctToolCallsArrayAtPath(updated, "tool_calls"); changed { + collect(changed, next) + } + if next, changed := c.correctFunctionAtPath(updated, "function_call"); changed { + collect(changed, next) + } + if next, changed := c.correctToolCallsArrayAtPath(updated, "delta.tool_calls"); changed { + collect(changed, next) + } + if next, changed := c.correctFunctionAtPath(updated, "delta.function_call"); changed { + collect(changed, next) } - // 处理 delta.tool_calls - if delta, ok := payload["delta"].(map[string]any); ok { - if toolCalls, ok := delta["tool_calls"].([]any); ok { - if c.correctToolCallsArray(toolCalls) { - corrected = true - } + choicesCount := int(gjson.GetBytes(updated, "choices.#").Int()) + for i := 0; i < choicesCount; i++ { + prefix := "choices." 
+ strconv.Itoa(i) + if next, changed := c.correctToolCallsArrayAtPath(updated, prefix+".message.tool_calls"); changed { + collect(changed, next) } - if functionCall, ok := delta["function_call"].(map[string]any); ok { - if c.correctFunctionCall(functionCall) { - corrected = true - } + if next, changed := c.correctFunctionAtPath(updated, prefix+".message.function_call"); changed { + collect(changed, next) } - } - - // 处理 choices[].message.tool_calls 和 choices[].delta.tool_calls - if choices, ok := payload["choices"].([]any); ok { - for _, choice := range choices { - if choiceMap, ok := choice.(map[string]any); ok { - // 处理 message 中的工具调用 - if message, ok := choiceMap["message"].(map[string]any); ok { - if toolCalls, ok := message["tool_calls"].([]any); ok { - if c.correctToolCallsArray(toolCalls) { - corrected = true - } - } - if functionCall, ok := message["function_call"].(map[string]any); ok { - if c.correctFunctionCall(functionCall) { - corrected = true - } - } - } - // 处理 delta 中的工具调用 - if delta, ok := choiceMap["delta"].(map[string]any); ok { - if toolCalls, ok := delta["tool_calls"].([]any); ok { - if c.correctToolCallsArray(toolCalls) { - corrected = true - } - } - if functionCall, ok := delta["function_call"].(map[string]any); ok { - if c.correctFunctionCall(functionCall) { - corrected = true - } - } - } - } + if next, changed := c.correctToolCallsArrayAtPath(updated, prefix+".delta.tool_calls"); changed { + collect(changed, next) + } + if next, changed := c.correctFunctionAtPath(updated, prefix+".delta.function_call"); changed { + collect(changed, next) } } if !corrected { return data, false } + return updated, true +} - // 序列化回 JSON - correctedBytes, err := json.Marshal(payload) - if err != nil { - log.Printf("[CodexToolCorrector] Failed to marshal corrected data: %v", err) +func mayContainToolCallPayload(data []byte) bool { + // 快速路径:多数 token / 文本事件不包含工具字段,避免进入 JSON 解析热路径。 + return bytes.Contains(data, []byte(`"tool_calls"`)) || + bytes.Contains(data, 
[]byte(`"function_call"`)) || + bytes.Contains(data, []byte(`"function":{"name"`)) +} + +// correctToolCallsArrayAtPath 修正指定路径下 tool_calls 数组中的工具名称。 +func (c *CodexToolCorrector) correctToolCallsArrayAtPath(data []byte, toolCallsPath string) ([]byte, bool) { + count := int(gjson.GetBytes(data, toolCallsPath+".#").Int()) + if count <= 0 { return data, false } - - return string(correctedBytes), true -} - -// correctToolCallsArray 修正工具调用数组中的工具名称 -func (c *CodexToolCorrector) correctToolCallsArray(toolCalls []any) bool { + updated := data corrected := false - for _, toolCall := range toolCalls { - if toolCallMap, ok := toolCall.(map[string]any); ok { - if function, ok := toolCallMap["function"].(map[string]any); ok { - if c.correctFunctionCall(function) { - corrected = true - } - } + for i := 0; i < count; i++ { + functionPath := toolCallsPath + "." + strconv.Itoa(i) + ".function" + if next, changed := c.correctFunctionAtPath(updated, functionPath); changed { + updated = next + corrected = true } } - return corrected + return updated, corrected } -// correctFunctionCall 修正单个函数调用的工具名称和参数 -func (c *CodexToolCorrector) correctFunctionCall(functionCall map[string]any) bool { - name, ok := functionCall["name"].(string) - if !ok || name == "" { - return false +// correctFunctionAtPath 修正指定路径下单个函数调用的工具名称和参数。 +func (c *CodexToolCorrector) correctFunctionAtPath(data []byte, functionPath string) ([]byte, bool) { + namePath := functionPath + ".name" + nameResult := gjson.GetBytes(data, namePath) + if !nameResult.Exists() || nameResult.Type != gjson.String { + return data, false } - + name := strings.TrimSpace(nameResult.Str) + if name == "" { + return data, false + } + updated := data corrected := false // 查找并修正工具名称 if correctName, found := codexToolNameMapping[name]; found { - functionCall["name"] = correctName - c.recordCorrection(name, correctName) - corrected = true - name = correctName // 使用修正后的名称进行参数修正 + if next, err := sjson.SetBytes(updated, namePath, correctName); err == 
nil { + updated = next + c.recordCorrection(name, correctName) + corrected = true + name = correctName // 使用修正后的名称进行参数修正 + } } // 修正工具参数(基于工具名称) - if c.correctToolParameters(name, functionCall) { + if next, changed := c.correctToolParametersAtPath(updated, functionPath+".arguments", name); changed { + updated = next corrected = true } - - return corrected + return updated, corrected } -// correctToolParameters 修正工具参数以符合 OpenCode 规范 -func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall map[string]any) bool { - arguments, ok := functionCall["arguments"] - if !ok { - return false +// correctToolParametersAtPath 修正指定路径下 arguments 参数。 +func (c *CodexToolCorrector) correctToolParametersAtPath(data []byte, argumentsPath, toolName string) ([]byte, bool) { + if toolName != "bash" && toolName != "edit" { + return data, false } - // arguments 可能是字符串(JSON)或已解析的 map - var argsMap map[string]any - switch v := arguments.(type) { - case string: - // 解析 JSON 字符串 - if err := json.Unmarshal([]byte(v), &argsMap); err != nil { - return false + args := gjson.GetBytes(data, argumentsPath) + if !args.Exists() { + return data, false + } + + switch args.Type { + case gjson.String: + argsJSON := strings.TrimSpace(args.Str) + if !gjson.Valid(argsJSON) { + return data, false } - case map[string]any: - argsMap = v + if !gjson.Parse(argsJSON).IsObject() { + return data, false + } + nextArgsJSON, corrected := c.correctToolArgumentsJSON(argsJSON, toolName) + if !corrected { + return data, false + } + next, err := sjson.SetBytes(data, argumentsPath, nextArgsJSON) + if err != nil { + return data, false + } + return next, true + case gjson.JSON: + if !args.IsObject() || !gjson.Valid(args.Raw) { + return data, false + } + nextArgsJSON, corrected := c.correctToolArgumentsJSON(args.Raw, toolName) + if !corrected { + return data, false + } + next, err := sjson.SetRawBytes(data, argumentsPath, []byte(nextArgsJSON)) + if err != nil { + return data, false + } + return next, true 
default: - return false + return data, false + } +} + +// correctToolArgumentsJSON 修正工具参数 JSON(对象字符串),返回修正后的 JSON 与是否变更。 +func (c *CodexToolCorrector) correctToolArgumentsJSON(argsJSON, toolName string) (string, bool) { + if !gjson.Valid(argsJSON) { + return argsJSON, false + } + if !gjson.Parse(argsJSON).IsObject() { + return argsJSON, false } + updated := argsJSON corrected := false // 根据工具名称应用特定的参数修正规则 switch toolName { case "bash": // OpenCode bash 支持 workdir;有些来源会输出 work_dir。 - if _, hasWorkdir := argsMap["workdir"]; !hasWorkdir { - if workDir, exists := argsMap["work_dir"]; exists { - argsMap["workdir"] = workDir - delete(argsMap, "work_dir") + if !gjson.Get(updated, "workdir").Exists() { + if next, changed := moveJSONField(updated, "work_dir", "workdir"); changed { + updated = next corrected = true - log.Printf("[CodexToolCorrector] Renamed 'work_dir' to 'workdir' in bash tool") + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'work_dir' to 'workdir' in bash tool") } } else { - if _, exists := argsMap["work_dir"]; exists { - delete(argsMap, "work_dir") + if next, changed := deleteJSONField(updated, "work_dir"); changed { + updated = next corrected = true - log.Printf("[CodexToolCorrector] Removed duplicate 'work_dir' parameter from bash tool") + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Removed duplicate 'work_dir' parameter from bash tool") } } case "edit": // OpenCode edit 参数为 filePath/oldString/newString(camelCase)。 - if _, exists := argsMap["filePath"]; !exists { - if filePath, exists := argsMap["file_path"]; exists { - argsMap["filePath"] = filePath - delete(argsMap, "file_path") + if !gjson.Get(updated, "filePath").Exists() { + if next, changed := moveJSONField(updated, "file_path", "filePath"); changed { + updated = next corrected = true - log.Printf("[CodexToolCorrector] Renamed 'file_path' to 'filePath' in edit tool") - } else if filePath, exists := argsMap["path"]; exists { - 
argsMap["filePath"] = filePath - delete(argsMap, "path") + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'file_path' to 'filePath' in edit tool") + } else if next, changed := moveJSONField(updated, "path", "filePath"); changed { + updated = next corrected = true - log.Printf("[CodexToolCorrector] Renamed 'path' to 'filePath' in edit tool") - } else if filePath, exists := argsMap["file"]; exists { - argsMap["filePath"] = filePath - delete(argsMap, "file") + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'path' to 'filePath' in edit tool") + } else if next, changed := moveJSONField(updated, "file", "filePath"); changed { + updated = next corrected = true - log.Printf("[CodexToolCorrector] Renamed 'file' to 'filePath' in edit tool") + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'file' to 'filePath' in edit tool") } } - if _, exists := argsMap["oldString"]; !exists { - if oldString, exists := argsMap["old_string"]; exists { - argsMap["oldString"] = oldString - delete(argsMap, "old_string") - corrected = true - log.Printf("[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool") - } + if next, changed := moveJSONField(updated, "old_string", "oldString"); changed { + updated = next + corrected = true + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool") } - if _, exists := argsMap["newString"]; !exists { - if newString, exists := argsMap["new_string"]; exists { - argsMap["newString"] = newString - delete(argsMap, "new_string") - corrected = true - log.Printf("[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool") - } + if next, changed := moveJSONField(updated, "new_string", "newString"); changed { + updated = next + corrected = true + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit 
tool") } - if _, exists := argsMap["replaceAll"]; !exists { - if replaceAll, exists := argsMap["replace_all"]; exists { - argsMap["replaceAll"] = replaceAll - delete(argsMap, "replace_all") - corrected = true - log.Printf("[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool") - } + if next, changed := moveJSONField(updated, "replace_all", "replaceAll"); changed { + updated = next + corrected = true + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool") } } + return updated, corrected +} - // 如果修正了参数,需要重新序列化 - if corrected { - if _, wasString := arguments.(string); wasString { - // 原本是字符串,序列化回字符串 - if newArgsJSON, err := json.Marshal(argsMap); err == nil { - functionCall["arguments"] = string(newArgsJSON) - } - } else { - // 原本是 map,直接赋值 - functionCall["arguments"] = argsMap - } +func moveJSONField(input, from, to string) (string, bool) { + if gjson.Get(input, to).Exists() { + return input, false } + src := gjson.Get(input, from) + if !src.Exists() { + return input, false + } + next, err := sjson.SetRaw(input, to, src.Raw) + if err != nil { + return input, false + } + next, err = sjson.Delete(next, from) + if err != nil { + return input, false + } + return next, true +} - return corrected +func deleteJSONField(input, path string) (string, bool) { + if !gjson.Get(input, path).Exists() { + return input, false + } + next, err := sjson.Delete(input, path) + if err != nil { + return input, false + } + return next, true } // recordCorrection 记录一次工具名称修正 @@ -303,7 +344,7 @@ func (c *CodexToolCorrector) recordCorrection(from, to string) { key := fmt.Sprintf("%s->%s", from, to) c.stats.CorrectionsByTool[key]++ - log.Printf("[CodexToolCorrector] Corrected tool call: %s -> %s (total: %d)", + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Corrected tool call: %s -> %s (total: %d)", from, to, c.stats.TotalCorrected) } diff --git 
a/backend/internal/service/openai_tool_corrector_test.go b/backend/internal/service/openai_tool_corrector_test.go index ff518ea6..7c83de9e 100644 --- a/backend/internal/service/openai_tool_corrector_test.go +++ b/backend/internal/service/openai_tool_corrector_test.go @@ -5,6 +5,15 @@ import ( "testing" ) +func TestMayContainToolCallPayload(t *testing.T) { + if mayContainToolCallPayload([]byte(`{"type":"response.output_text.delta","delta":"hello"}`)) { + t.Fatalf("plain text event should not trigger tool-call parsing") + } + if !mayContainToolCallPayload([]byte(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)) { + t.Fatalf("tool_calls event should trigger tool-call parsing") + } +} + func TestCorrectToolCallsInSSEData(t *testing.T) { corrector := NewCodexToolCorrector() diff --git a/backend/internal/service/openai_ws_account_sticky_test.go b/backend/internal/service/openai_ws_account_sticky_test.go new file mode 100644 index 00000000..3fe08179 --- /dev/null +++ b/backend/internal/service/openai_ws_account_sticky_test.go @@ -0,0 +1,190 @@ +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit(t *testing.T) { + ctx := context.Background() + groupID := int64(23) + account := Account{ + ID: 2, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 2, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + cfg := newOpenAIWSV2TestConfig() + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + openaiWSStateStore: store, + } + + require.NoError(t, store.BindResponseAccount(ctx, groupID, 
"resp_prev_1", account.ID, time.Hour)) + + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_1", "gpt-5.1", nil) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, account.ID, selection.Account.ID) + require.True(t, selection.Acquired) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) { + ctx := context.Background() + groupID := int64(23) + account := Account{ + ID: 8, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + cfg := newOpenAIWSV2TestConfig() + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + openaiWSStateStore: store, + } + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_2", account.ID, time.Hour)) + + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_2", "gpt-5.1", map[int64]struct{}{account.ID: {}}) + require.NoError(t, err) + require.Nil(t, selection) +} + +func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_ForceHTTPIgnored(t *testing.T) { + ctx := context.Background() + groupID := int64(23) + account := Account{ + ID: 11, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "openai_ws_force_http": true, + "responses_websockets_v2_enabled": true, + }, + } + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + cfg := newOpenAIWSV2TestConfig() + svc := &OpenAIGatewayService{ + accountRepo: 
stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + openaiWSStateStore: store, + } + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_force_http", account.ID, time.Hour)) + + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_force_http", "gpt-5.1", nil) + require.NoError(t, err) + require.Nil(t, selection, "force_http 场景应忽略 previous_response_id 粘连") +} + +func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_BusyKeepsSticky(t *testing.T) { + ctx := context.Background() + groupID := int64(23) + accounts := []Account{ + { + ID: 21, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + }, + { + ID: 22, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 9, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + }, + } + + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + cfg := newOpenAIWSV2TestConfig() + cfg.Gateway.Scheduling.StickySessionMaxWaiting = 2 + cfg.Gateway.Scheduling.StickySessionWaitTimeout = 30 * time.Second + + concurrencyCache := stubConcurrencyCache{ + acquireResults: map[int64]bool{ + 21: false, // previous_response 命中的账号繁忙 + 22: true, // 次优账号可用(若回退会命中) + }, + waitCounts: map[int64]int{ + 21: 999, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + openaiWSStateStore: store, + } + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_busy", 21, time.Hour)) + + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, 
"resp_prev_busy", "gpt-5.1", nil) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(21), selection.Account.ID, "busy previous_response sticky account should remain selected") + require.False(t, selection.Acquired) + require.NotNil(t, selection.WaitPlan) + require.Equal(t, int64(21), selection.WaitPlan.AccountID) +} + +func newOpenAIWSV2TestConfig() *config.Config { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + return cfg +} diff --git a/backend/internal/service/openai_ws_client.go b/backend/internal/service/openai_ws_client.go new file mode 100644 index 00000000..9f3c47b7 --- /dev/null +++ b/backend/internal/service/openai_ws_client.go @@ -0,0 +1,285 @@ +package service + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + "sync" + "sync/atomic" + "time" + + coderws "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" +) + +const openAIWSMessageReadLimitBytes int64 = 16 * 1024 * 1024 +const ( + openAIWSProxyTransportMaxIdleConns = 128 + openAIWSProxyTransportMaxIdleConnsPerHost = 64 + openAIWSProxyTransportIdleConnTimeout = 90 * time.Second + openAIWSProxyClientCacheMaxEntries = 256 + openAIWSProxyClientCacheIdleTTL = 15 * time.Minute +) + +type OpenAIWSTransportMetricsSnapshot struct { + ProxyClientCacheHits int64 `json:"proxy_client_cache_hits"` + ProxyClientCacheMisses int64 `json:"proxy_client_cache_misses"` + TransportReuseRatio float64 `json:"transport_reuse_ratio"` +} + +// openAIWSClientConn 抽象 WS 客户端连接,便于替换底层实现。 +type openAIWSClientConn interface { + WriteJSON(ctx context.Context, value any) error + ReadMessage(ctx context.Context) ([]byte, error) + Ping(ctx context.Context) error + Close() error +} + +// openAIWSClientDialer 抽象 WS 
建连器。 +type openAIWSClientDialer interface { + Dial(ctx context.Context, wsURL string, headers http.Header, proxyURL string) (openAIWSClientConn, int, http.Header, error) +} + +type openAIWSTransportMetricsDialer interface { + SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot +} + +func newDefaultOpenAIWSClientDialer() openAIWSClientDialer { + return &coderOpenAIWSClientDialer{ + proxyClients: make(map[string]*openAIWSProxyClientEntry), + } +} + +type coderOpenAIWSClientDialer struct { + proxyMu sync.Mutex + proxyClients map[string]*openAIWSProxyClientEntry + proxyHits atomic.Int64 + proxyMisses atomic.Int64 +} + +type openAIWSProxyClientEntry struct { + client *http.Client + lastUsedUnixNano int64 +} + +func (d *coderOpenAIWSClientDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + targetURL := strings.TrimSpace(wsURL) + if targetURL == "" { + return nil, 0, nil, errors.New("ws url is empty") + } + + opts := &coderws.DialOptions{ + HTTPHeader: cloneHeader(headers), + CompressionMode: coderws.CompressionContextTakeover, + } + if proxy := strings.TrimSpace(proxyURL); proxy != "" { + proxyClient, err := d.proxyHTTPClient(proxy) + if err != nil { + return nil, 0, nil, err + } + opts.HTTPClient = proxyClient + } + + conn, resp, err := coderws.Dial(ctx, targetURL, opts) + if err != nil { + status := 0 + respHeaders := http.Header(nil) + if resp != nil { + status = resp.StatusCode + respHeaders = cloneHeader(resp.Header) + } + return nil, status, respHeaders, err + } + // coder/websocket 默认单消息读取上限为 32KB,Codex WS 事件(如 rate_limits/大 delta) + // 可能超过该阈值,需显式提高上限,避免本地 read_fail(message too big)。 + conn.SetReadLimit(openAIWSMessageReadLimitBytes) + respHeaders := http.Header(nil) + if resp != nil { + respHeaders = cloneHeader(resp.Header) + } + return &coderOpenAIWSClientConn{conn: conn}, 0, respHeaders, nil +} + +func (d *coderOpenAIWSClientDialer) proxyHTTPClient(proxy 
string) (*http.Client, error) { + if d == nil { + return nil, errors.New("openai ws dialer is nil") + } + normalizedProxy := strings.TrimSpace(proxy) + if normalizedProxy == "" { + return nil, errors.New("proxy url is empty") + } + parsedProxyURL, err := url.Parse(normalizedProxy) + if err != nil { + return nil, fmt.Errorf("invalid proxy url: %w", err) + } + now := time.Now().UnixNano() + + d.proxyMu.Lock() + defer d.proxyMu.Unlock() + if entry, ok := d.proxyClients[normalizedProxy]; ok && entry != nil && entry.client != nil { + entry.lastUsedUnixNano = now + d.proxyHits.Add(1) + return entry.client, nil + } + d.cleanupProxyClientsLocked(now) + transport := &http.Transport{ + Proxy: http.ProxyURL(parsedProxyURL), + MaxIdleConns: openAIWSProxyTransportMaxIdleConns, + MaxIdleConnsPerHost: openAIWSProxyTransportMaxIdleConnsPerHost, + IdleConnTimeout: openAIWSProxyTransportIdleConnTimeout, + TLSHandshakeTimeout: 10 * time.Second, + ForceAttemptHTTP2: true, + } + client := &http.Client{Transport: transport} + d.proxyClients[normalizedProxy] = &openAIWSProxyClientEntry{ + client: client, + lastUsedUnixNano: now, + } + d.ensureProxyClientCapacityLocked() + d.proxyMisses.Add(1) + return client, nil +} + +func (d *coderOpenAIWSClientDialer) cleanupProxyClientsLocked(nowUnixNano int64) { + if d == nil || len(d.proxyClients) == 0 { + return + } + idleTTL := openAIWSProxyClientCacheIdleTTL + if idleTTL <= 0 { + return + } + now := time.Unix(0, nowUnixNano) + for key, entry := range d.proxyClients { + if entry == nil || entry.client == nil { + delete(d.proxyClients, key) + continue + } + lastUsed := time.Unix(0, entry.lastUsedUnixNano) + if now.Sub(lastUsed) > idleTTL { + closeOpenAIWSProxyClient(entry.client) + delete(d.proxyClients, key) + } + } +} + +func (d *coderOpenAIWSClientDialer) ensureProxyClientCapacityLocked() { + if d == nil { + return + } + maxEntries := openAIWSProxyClientCacheMaxEntries + if maxEntries <= 0 { + return + } + for len(d.proxyClients) > maxEntries { 
+ var oldestKey string + var oldestLastUsed int64 + hasOldest := false + for key, entry := range d.proxyClients { + lastUsed := int64(0) + if entry != nil { + lastUsed = entry.lastUsedUnixNano + } + if !hasOldest || lastUsed < oldestLastUsed { + hasOldest = true + oldestKey = key + oldestLastUsed = lastUsed + } + } + if !hasOldest { + return + } + if entry := d.proxyClients[oldestKey]; entry != nil { + closeOpenAIWSProxyClient(entry.client) + } + delete(d.proxyClients, oldestKey) + } +} + +func closeOpenAIWSProxyClient(client *http.Client) { + if client == nil || client.Transport == nil { + return + } + if transport, ok := client.Transport.(*http.Transport); ok && transport != nil { + transport.CloseIdleConnections() + } +} + +func (d *coderOpenAIWSClientDialer) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot { + if d == nil { + return OpenAIWSTransportMetricsSnapshot{} + } + hits := d.proxyHits.Load() + misses := d.proxyMisses.Load() + total := hits + misses + reuseRatio := 0.0 + if total > 0 { + reuseRatio = float64(hits) / float64(total) + } + return OpenAIWSTransportMetricsSnapshot{ + ProxyClientCacheHits: hits, + ProxyClientCacheMisses: misses, + TransportReuseRatio: reuseRatio, + } +} + +type coderOpenAIWSClientConn struct { + conn *coderws.Conn +} + +func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error { + if c == nil || c.conn == nil { + return errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + return wsjson.Write(ctx, c.conn, value) +} + +func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, error) { + if c == nil || c.conn == nil { + return nil, errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + + msgType, payload, err := c.conn.Read(ctx) + if err != nil { + return nil, err + } + switch msgType { + case coderws.MessageText, coderws.MessageBinary: + return payload, nil + default: + return nil, errOpenAIWSConnClosed + } +} + +func (c 
*coderOpenAIWSClientConn) Ping(ctx context.Context) error { + if c == nil || c.conn == nil { + return errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + return c.conn.Ping(ctx) +} + +func (c *coderOpenAIWSClientConn) Close() error { + if c == nil || c.conn == nil { + return nil + } + // Close 为幂等,忽略重复关闭错误。 + _ = c.conn.Close(coderws.StatusNormalClosure, "") + _ = c.conn.CloseNow() + return nil +} diff --git a/backend/internal/service/openai_ws_client_test.go b/backend/internal/service/openai_ws_client_test.go new file mode 100644 index 00000000..a88d6266 --- /dev/null +++ b/backend/internal/service/openai_ws_client_test.go @@ -0,0 +1,112 @@ +package service + +import ( + "fmt" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestCoderOpenAIWSClientDialer_ProxyHTTPClientReuse(t *testing.T) { + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + c1, err := impl.proxyHTTPClient("http://127.0.0.1:8080") + require.NoError(t, err) + c2, err := impl.proxyHTTPClient("http://127.0.0.1:8080") + require.NoError(t, err) + require.Same(t, c1, c2, "同一代理地址应复用同一个 HTTP 客户端") + + c3, err := impl.proxyHTTPClient("http://127.0.0.1:8081") + require.NoError(t, err) + require.NotSame(t, c1, c3, "不同代理地址应分离客户端") +} + +func TestCoderOpenAIWSClientDialer_ProxyHTTPClientInvalidURL(t *testing.T) { + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + _, err := impl.proxyHTTPClient("://bad") + require.Error(t, err) +} + +func TestCoderOpenAIWSClientDialer_TransportMetricsSnapshot(t *testing.T) { + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + _, err := impl.proxyHTTPClient("http://127.0.0.1:18080") + require.NoError(t, err) + _, err = impl.proxyHTTPClient("http://127.0.0.1:18080") + require.NoError(t, err) + _, err = 
impl.proxyHTTPClient("http://127.0.0.1:18081") + require.NoError(t, err) + + snapshot := impl.SnapshotTransportMetrics() + require.Equal(t, int64(1), snapshot.ProxyClientCacheHits) + require.Equal(t, int64(2), snapshot.ProxyClientCacheMisses) + require.InDelta(t, 1.0/3.0, snapshot.TransportReuseRatio, 0.0001) +} + +func TestCoderOpenAIWSClientDialer_ProxyClientCacheCapacity(t *testing.T) { + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + total := openAIWSProxyClientCacheMaxEntries + 32 + for i := 0; i < total; i++ { + _, err := impl.proxyHTTPClient(fmt.Sprintf("http://127.0.0.1:%d", 20000+i)) + require.NoError(t, err) + } + + impl.proxyMu.Lock() + cacheSize := len(impl.proxyClients) + impl.proxyMu.Unlock() + + require.LessOrEqual(t, cacheSize, openAIWSProxyClientCacheMaxEntries, "代理客户端缓存应受容量上限约束") +} + +func TestCoderOpenAIWSClientDialer_ProxyClientCacheIdleTTL(t *testing.T) { + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + oldProxy := "http://127.0.0.1:28080" + _, err := impl.proxyHTTPClient(oldProxy) + require.NoError(t, err) + + impl.proxyMu.Lock() + oldEntry := impl.proxyClients[oldProxy] + require.NotNil(t, oldEntry) + oldEntry.lastUsedUnixNano = time.Now().Add(-openAIWSProxyClientCacheIdleTTL - time.Minute).UnixNano() + impl.proxyMu.Unlock() + + // 触发一次新的代理获取,驱动 TTL 清理。 + _, err = impl.proxyHTTPClient("http://127.0.0.1:28081") + require.NoError(t, err) + + impl.proxyMu.Lock() + _, exists := impl.proxyClients[oldProxy] + impl.proxyMu.Unlock() + + require.False(t, exists, "超过空闲 TTL 的代理客户端应被回收") +} + +func TestCoderOpenAIWSClientDialer_ProxyTransportTLSHandshakeTimeout(t *testing.T) { + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + client, err := impl.proxyHTTPClient("http://127.0.0.1:38080") + require.NoError(t, err) + require.NotNil(t, client) + 
+ transport, ok := client.Transport.(*http.Transport) + require.True(t, ok) + require.NotNil(t, transport) + require.Equal(t, 10*time.Second, transport.TLSHandshakeTimeout) +} diff --git a/backend/internal/service/openai_ws_fallback_test.go b/backend/internal/service/openai_ws_fallback_test.go new file mode 100644 index 00000000..ce06f6a2 --- /dev/null +++ b/backend/internal/service/openai_ws_fallback_test.go @@ -0,0 +1,251 @@ +package service + +import ( + "context" + "errors" + "net/http" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + coderws "github.com/coder/websocket" + "github.com/stretchr/testify/require" +) + +func TestClassifyOpenAIWSAcquireError(t *testing.T) { + t.Run("dial_426_upgrade_required", func(t *testing.T) { + err := &openAIWSDialError{StatusCode: 426, Err: errors.New("upgrade required")} + require.Equal(t, "upgrade_required", classifyOpenAIWSAcquireError(err)) + }) + + t.Run("queue_full", func(t *testing.T) { + require.Equal(t, "conn_queue_full", classifyOpenAIWSAcquireError(errOpenAIWSConnQueueFull)) + }) + + t.Run("preferred_conn_unavailable", func(t *testing.T) { + require.Equal(t, "preferred_conn_unavailable", classifyOpenAIWSAcquireError(errOpenAIWSPreferredConnUnavailable)) + }) + + t.Run("acquire_timeout", func(t *testing.T) { + require.Equal(t, "acquire_timeout", classifyOpenAIWSAcquireError(context.DeadlineExceeded)) + }) + + t.Run("auth_failed_401", func(t *testing.T) { + err := &openAIWSDialError{StatusCode: 401, Err: errors.New("unauthorized")} + require.Equal(t, "auth_failed", classifyOpenAIWSAcquireError(err)) + }) + + t.Run("upstream_rate_limited", func(t *testing.T) { + err := &openAIWSDialError{StatusCode: 429, Err: errors.New("rate limited")} + require.Equal(t, "upstream_rate_limited", classifyOpenAIWSAcquireError(err)) + }) + + t.Run("upstream_5xx", func(t *testing.T) { + err := &openAIWSDialError{StatusCode: 502, Err: errors.New("bad gateway")} + require.Equal(t, "upstream_5xx", 
classifyOpenAIWSAcquireError(err)) + }) + + t.Run("dial_failed_other_status", func(t *testing.T) { + err := &openAIWSDialError{StatusCode: 418, Err: errors.New("teapot")} + require.Equal(t, "dial_failed", classifyOpenAIWSAcquireError(err)) + }) + + t.Run("other", func(t *testing.T) { + require.Equal(t, "acquire_conn", classifyOpenAIWSAcquireError(errors.New("x"))) + }) + + t.Run("nil", func(t *testing.T) { + require.Equal(t, "acquire_conn", classifyOpenAIWSAcquireError(nil)) + }) +} + +func TestClassifyOpenAIWSDialError(t *testing.T) { + t.Run("handshake_not_finished", func(t *testing.T) { + err := &openAIWSDialError{ + StatusCode: http.StatusBadGateway, + Err: errors.New("WebSocket protocol error: Handshake not finished"), + } + require.Equal(t, "handshake_not_finished", classifyOpenAIWSDialError(err)) + }) + + t.Run("context_deadline", func(t *testing.T) { + err := &openAIWSDialError{ + StatusCode: 0, + Err: context.DeadlineExceeded, + } + require.Equal(t, "ctx_deadline_exceeded", classifyOpenAIWSDialError(err)) + }) +} + +func TestSummarizeOpenAIWSDialError(t *testing.T) { + err := &openAIWSDialError{ + StatusCode: http.StatusBadGateway, + ResponseHeaders: http.Header{ + "Server": []string{"cloudflare"}, + "Via": []string{"1.1 example"}, + "Cf-Ray": []string{"abcd1234"}, + "X-Request-Id": []string{"req_123"}, + }, + Err: errors.New("WebSocket protocol error: Handshake not finished"), + } + + status, class, closeStatus, closeReason, server, via, cfRay, reqID := summarizeOpenAIWSDialError(err) + require.Equal(t, http.StatusBadGateway, status) + require.Equal(t, "handshake_not_finished", class) + require.Equal(t, "-", closeStatus) + require.Equal(t, "-", closeReason) + require.Equal(t, "cloudflare", server) + require.Equal(t, "1.1 example", via) + require.Equal(t, "abcd1234", cfRay) + require.Equal(t, "req_123", reqID) +} + +func TestClassifyOpenAIWSErrorEvent(t *testing.T) { + reason, recoverable := 
classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"code":"upgrade_required","message":"Upgrade required"}}`)) + require.Equal(t, "upgrade_required", reason) + require.True(t, recoverable) + + reason, recoverable = classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"code":"previous_response_not_found","message":"not found"}}`)) + require.Equal(t, "previous_response_not_found", reason) + require.True(t, recoverable) +} + +func TestClassifyOpenAIWSReconnectReason(t *testing.T) { + reason, retryable := classifyOpenAIWSReconnectReason(wrapOpenAIWSFallback("policy_violation", errors.New("policy"))) + require.Equal(t, "policy_violation", reason) + require.False(t, retryable) + + reason, retryable = classifyOpenAIWSReconnectReason(wrapOpenAIWSFallback("read_event", errors.New("io"))) + require.Equal(t, "read_event", reason) + require.True(t, retryable) +} + +func TestOpenAIWSErrorHTTPStatus(t *testing.T) { + require.Equal(t, http.StatusBadRequest, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`))) + require.Equal(t, http.StatusUnauthorized, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"authentication_error","code":"invalid_api_key","message":"auth failed"}}`))) + require.Equal(t, http.StatusForbidden, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"permission_error","code":"forbidden","message":"forbidden"}}`))) + require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"rate_limit_error","code":"rate_limit_exceeded","message":"rate limited"}}`))) + require.Equal(t, http.StatusBadGateway, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"server_error","code":"server_error","message":"server"}}`))) +} + +func TestResolveOpenAIWSFallbackErrorResponse(t *testing.T) { + t.Run("previous_response_not_found", func(t *testing.T) { + statusCode, errType, clientMessage, 
upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse( + wrapOpenAIWSFallback("previous_response_not_found", errors.New("previous response not found")), + ) + require.True(t, ok) + require.Equal(t, http.StatusBadRequest, statusCode) + require.Equal(t, "invalid_request_error", errType) + require.Equal(t, "previous response not found", clientMessage) + require.Equal(t, "previous response not found", upstreamMessage) + }) + + t.Run("auth_failed_uses_dial_status", func(t *testing.T) { + statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse( + wrapOpenAIWSFallback("auth_failed", &openAIWSDialError{ + StatusCode: http.StatusForbidden, + Err: errors.New("forbidden"), + }), + ) + require.True(t, ok) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, "upstream_error", errType) + require.Equal(t, "forbidden", clientMessage) + require.Equal(t, "forbidden", upstreamMessage) + }) + + t.Run("non_fallback_error_not_resolved", func(t *testing.T) { + _, _, _, _, ok := resolveOpenAIWSFallbackErrorResponse(errors.New("plain error")) + require.False(t, ok) + }) +} + +func TestOpenAIWSFallbackCooling(t *testing.T) { + svc := &OpenAIGatewayService{cfg: &config.Config{}} + svc.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + require.False(t, svc.isOpenAIWSFallbackCooling(1)) + svc.markOpenAIWSFallbackCooling(1, "upgrade_required") + require.True(t, svc.isOpenAIWSFallbackCooling(1)) + + svc.clearOpenAIWSFallbackCooling(1) + require.False(t, svc.isOpenAIWSFallbackCooling(1)) + + svc.markOpenAIWSFallbackCooling(2, "x") + time.Sleep(1200 * time.Millisecond) + require.False(t, svc.isOpenAIWSFallbackCooling(2)) +} + +func TestOpenAIWSRetryBackoff(t *testing.T) { + svc := &OpenAIGatewayService{cfg: &config.Config{}} + svc.cfg.Gateway.OpenAIWS.RetryBackoffInitialMS = 100 + svc.cfg.Gateway.OpenAIWS.RetryBackoffMaxMS = 400 + svc.cfg.Gateway.OpenAIWS.RetryJitterRatio = 0 + + require.Equal(t, 
time.Duration(100)*time.Millisecond, svc.openAIWSRetryBackoff(1)) + require.Equal(t, time.Duration(200)*time.Millisecond, svc.openAIWSRetryBackoff(2)) + require.Equal(t, time.Duration(400)*time.Millisecond, svc.openAIWSRetryBackoff(3)) + require.Equal(t, time.Duration(400)*time.Millisecond, svc.openAIWSRetryBackoff(4)) +} + +func TestOpenAIWSRetryTotalBudget(t *testing.T) { + svc := &OpenAIGatewayService{cfg: &config.Config{}} + svc.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS = 1200 + require.Equal(t, 1200*time.Millisecond, svc.openAIWSRetryTotalBudget()) + + svc.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS = 0 + require.Equal(t, time.Duration(0), svc.openAIWSRetryTotalBudget()) +} + +func TestClassifyOpenAIWSReadFallbackReason(t *testing.T) { + require.Equal(t, "policy_violation", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusPolicyViolation})) + require.Equal(t, "message_too_big", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusMessageTooBig})) + require.Equal(t, "read_event", classifyOpenAIWSReadFallbackReason(errors.New("io"))) +} + +func TestOpenAIWSStoreDisabledConnMode(t *testing.T) { + svc := &OpenAIGatewayService{cfg: &config.Config{}} + svc.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = true + require.Equal(t, openAIWSStoreDisabledConnModeStrict, svc.openAIWSStoreDisabledConnMode()) + + svc.cfg.Gateway.OpenAIWS.StoreDisabledConnMode = "adaptive" + require.Equal(t, openAIWSStoreDisabledConnModeAdaptive, svc.openAIWSStoreDisabledConnMode()) + + svc.cfg.Gateway.OpenAIWS.StoreDisabledConnMode = "" + svc.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = false + require.Equal(t, openAIWSStoreDisabledConnModeOff, svc.openAIWSStoreDisabledConnMode()) +} + +func TestShouldForceNewConnOnStoreDisabled(t *testing.T) { + require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeStrict, "")) + require.False(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeOff, "policy_violation")) + + 
require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "policy_violation")) + require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "prewarm_message_too_big")) + require.False(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "read_event")) +} + +func TestOpenAIWSRetryMetricsSnapshot(t *testing.T) { + svc := &OpenAIGatewayService{} + svc.recordOpenAIWSRetryAttempt(150 * time.Millisecond) + svc.recordOpenAIWSRetryAttempt(0) + svc.recordOpenAIWSRetryExhausted() + svc.recordOpenAIWSNonRetryableFastFallback() + + snapshot := svc.SnapshotOpenAIWSRetryMetrics() + require.Equal(t, int64(2), snapshot.RetryAttemptsTotal) + require.Equal(t, int64(150), snapshot.RetryBackoffMsTotal) + require.Equal(t, int64(1), snapshot.RetryExhaustedTotal) + require.Equal(t, int64(1), snapshot.NonRetryableFastFallbackTotal) +} + +func TestShouldLogOpenAIWSPayloadSchema(t *testing.T) { + svc := &OpenAIGatewayService{cfg: &config.Config{}} + + svc.cfg.Gateway.OpenAIWS.PayloadLogSampleRate = 0 + require.True(t, svc.shouldLogOpenAIWSPayloadSchema(1), "首次尝试应始终记录 payload_schema") + require.False(t, svc.shouldLogOpenAIWSPayloadSchema(2)) + + svc.cfg.Gateway.OpenAIWS.PayloadLogSampleRate = 1 + require.True(t, svc.shouldLogOpenAIWSPayloadSchema(2)) +} diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go new file mode 100644 index 00000000..74ba472f --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder.go @@ -0,0 +1,3955 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "math/rand" + "net" + "net/http" + "net/url" + "sort" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + 
"github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.uber.org/zap" +) + +const ( + openAIWSBetaV1Value = "responses_websockets=2026-02-04" + openAIWSBetaV2Value = "responses_websockets=2026-02-06" + + openAIWSTurnStateHeader = "x-codex-turn-state" + openAIWSTurnMetadataHeader = "x-codex-turn-metadata" + + openAIWSLogValueMaxLen = 160 + openAIWSHeaderValueMaxLen = 120 + openAIWSIDValueMaxLen = 64 + openAIWSEventLogHeadLimit = 20 + openAIWSEventLogEveryN = 50 + openAIWSBufferLogHeadLimit = 8 + openAIWSBufferLogEveryN = 20 + openAIWSPrewarmEventLogHead = 10 + openAIWSPayloadKeySizeTopN = 6 + + openAIWSPayloadSizeEstimateDepth = 3 + openAIWSPayloadSizeEstimateMaxBytes = 64 * 1024 + openAIWSPayloadSizeEstimateMaxItems = 16 + + openAIWSEventFlushBatchSizeDefault = 4 + openAIWSEventFlushIntervalDefault = 25 * time.Millisecond + openAIWSPayloadLogSampleDefault = 0.2 + + openAIWSStoreDisabledConnModeStrict = "strict" + openAIWSStoreDisabledConnModeAdaptive = "adaptive" + openAIWSStoreDisabledConnModeOff = "off" + + openAIWSIngressStagePreviousResponseNotFound = "previous_response_not_found" + openAIWSMaxPrevResponseIDDeletePasses = 8 +) + +var openAIWSLogValueReplacer = strings.NewReplacer( + "error", "err", + "fallback", "fb", + "warning", "warnx", + "failed", "fail", +) + +var openAIWSIngressPreflightPingIdle = 20 * time.Second + +// openAIWSFallbackError 表示可安全回退到 HTTP 的 WS 错误(尚未写下游)。 +type openAIWSFallbackError struct { + Reason string + Err error +} + +func (e *openAIWSFallbackError) Error() string { + if e == nil { + return "" + } + if e.Err == nil { + return fmt.Sprintf("openai ws fallback: %s", strings.TrimSpace(e.Reason)) + } + return fmt.Sprintf("openai ws fallback: %s: %v", strings.TrimSpace(e.Reason), e.Err) +} + +func (e *openAIWSFallbackError) Unwrap() error { + if e == nil { + return nil + } + return e.Err +} + +func wrapOpenAIWSFallback(reason string, err error) error { + return &openAIWSFallbackError{Reason: strings.TrimSpace(reason), Err: err} +} + 
+// OpenAIWSClientCloseError 表示应以指定 WebSocket close code 主动关闭客户端连接的错误。 +type OpenAIWSClientCloseError struct { + statusCode coderws.StatusCode + reason string + err error +} + +type openAIWSIngressTurnError struct { + stage string + cause error + wroteDownstream bool +} + +func (e *openAIWSIngressTurnError) Error() string { + if e == nil { + return "" + } + if e.cause == nil { + return strings.TrimSpace(e.stage) + } + return e.cause.Error() +} + +func (e *openAIWSIngressTurnError) Unwrap() error { + if e == nil { + return nil + } + return e.cause +} + +func wrapOpenAIWSIngressTurnError(stage string, cause error, wroteDownstream bool) error { + if cause == nil { + return nil + } + return &openAIWSIngressTurnError{ + stage: strings.TrimSpace(stage), + cause: cause, + wroteDownstream: wroteDownstream, + } +} + +func isOpenAIWSIngressTurnRetryable(err error) bool { + var turnErr *openAIWSIngressTurnError + if !errors.As(err, &turnErr) || turnErr == nil { + return false + } + if errors.Is(turnErr.cause, context.Canceled) || errors.Is(turnErr.cause, context.DeadlineExceeded) { + return false + } + if turnErr.wroteDownstream { + return false + } + switch turnErr.stage { + case "write_upstream", "read_upstream": + return true + default: + return false + } +} + +func openAIWSIngressTurnRetryReason(err error) string { + var turnErr *openAIWSIngressTurnError + if !errors.As(err, &turnErr) || turnErr == nil { + return "unknown" + } + if turnErr.stage == "" { + return "unknown" + } + return turnErr.stage +} + +func isOpenAIWSIngressPreviousResponseNotFound(err error) bool { + var turnErr *openAIWSIngressTurnError + if !errors.As(err, &turnErr) || turnErr == nil { + return false + } + if strings.TrimSpace(turnErr.stage) != openAIWSIngressStagePreviousResponseNotFound { + return false + } + return !turnErr.wroteDownstream +} + +// NewOpenAIWSClientCloseError 创建一个客户端 WS 关闭错误。 +func NewOpenAIWSClientCloseError(statusCode coderws.StatusCode, reason string, err error) error { + 
return &OpenAIWSClientCloseError{ + statusCode: statusCode, + reason: strings.TrimSpace(reason), + err: err, + } +} + +func (e *OpenAIWSClientCloseError) Error() string { + if e == nil { + return "" + } + if e.err == nil { + return fmt.Sprintf("openai ws client close: %d %s", int(e.statusCode), strings.TrimSpace(e.reason)) + } + return fmt.Sprintf("openai ws client close: %d %s: %v", int(e.statusCode), strings.TrimSpace(e.reason), e.err) +} + +func (e *OpenAIWSClientCloseError) Unwrap() error { + if e == nil { + return nil + } + return e.err +} + +func (e *OpenAIWSClientCloseError) StatusCode() coderws.StatusCode { + if e == nil { + return coderws.StatusInternalError + } + return e.statusCode +} + +func (e *OpenAIWSClientCloseError) Reason() string { + if e == nil { + return "" + } + return strings.TrimSpace(e.reason) +} + +// OpenAIWSIngressHooks 定义入站 WS 每个 turn 的生命周期回调。 +type OpenAIWSIngressHooks struct { + BeforeTurn func(turn int) error + AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error) +} + +func normalizeOpenAIWSLogValue(value string) string { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "-" + } + return openAIWSLogValueReplacer.Replace(trimmed) +} + +func truncateOpenAIWSLogValue(value string, maxLen int) string { + normalized := normalizeOpenAIWSLogValue(value) + if normalized == "-" || maxLen <= 0 { + return normalized + } + if len(normalized) <= maxLen { + return normalized + } + return normalized[:maxLen] + "..." 
+} + +func openAIWSHeaderValueForLog(headers http.Header, key string) string { + if headers == nil { + return "-" + } + return truncateOpenAIWSLogValue(headers.Get(key), openAIWSHeaderValueMaxLen) +} + +func hasOpenAIWSHeader(headers http.Header, key string) bool { + if headers == nil { + return false + } + return strings.TrimSpace(headers.Get(key)) != "" +} + +type openAIWSSessionHeaderResolution struct { + SessionID string + ConversationID string + SessionSource string + ConversationSource string +} + +func resolveOpenAIWSSessionHeaders(c *gin.Context, promptCacheKey string) openAIWSSessionHeaderResolution { + resolution := openAIWSSessionHeaderResolution{ + SessionSource: "none", + ConversationSource: "none", + } + if c != nil && c.Request != nil { + if sessionID := strings.TrimSpace(c.Request.Header.Get("session_id")); sessionID != "" { + resolution.SessionID = sessionID + resolution.SessionSource = "header_session_id" + } + if conversationID := strings.TrimSpace(c.Request.Header.Get("conversation_id")); conversationID != "" { + resolution.ConversationID = conversationID + resolution.ConversationSource = "header_conversation_id" + if resolution.SessionID == "" { + resolution.SessionID = conversationID + resolution.SessionSource = "header_conversation_id" + } + } + } + + cacheKey := strings.TrimSpace(promptCacheKey) + if cacheKey != "" { + if resolution.SessionID == "" { + resolution.SessionID = cacheKey + resolution.SessionSource = "prompt_cache_key" + } + } + return resolution +} + +func shouldLogOpenAIWSEvent(idx int, eventType string) bool { + if idx <= openAIWSEventLogHeadLimit { + return true + } + if openAIWSEventLogEveryN > 0 && idx%openAIWSEventLogEveryN == 0 { + return true + } + if eventType == "error" || isOpenAIWSTerminalEvent(eventType) { + return true + } + return false +} + +func shouldLogOpenAIWSBufferedEvent(idx int) bool { + if idx <= openAIWSBufferLogHeadLimit { + return true + } + if openAIWSBufferLogEveryN > 0 && idx%openAIWSBufferLogEveryN 
== 0 { + return true + } + return false +} + +func openAIWSEventMayContainModel(eventType string) bool { + switch eventType { + case "response.created", + "response.in_progress", + "response.completed", + "response.done", + "response.failed", + "response.incomplete", + "response.cancelled", + "response.canceled": + return true + default: + trimmed := strings.TrimSpace(eventType) + if trimmed == eventType { + return false + } + switch trimmed { + case "response.created", + "response.in_progress", + "response.completed", + "response.done", + "response.failed", + "response.incomplete", + "response.cancelled", + "response.canceled": + return true + default: + return false + } + } +} + +func openAIWSEventMayContainToolCalls(eventType string) bool { + eventType = strings.TrimSpace(eventType) + if eventType == "" { + return false + } + if strings.Contains(eventType, "function_call") || strings.Contains(eventType, "tool_call") { + return true + } + switch eventType { + case "response.output_item.added", "response.output_item.done", "response.completed", "response.done": + return true + default: + return false + } +} + +func openAIWSEventShouldParseUsage(eventType string) bool { + return eventType == "response.completed" || strings.TrimSpace(eventType) == "response.completed" +} + +func parseOpenAIWSEventEnvelope(message []byte) (eventType string, responseID string, response gjson.Result) { + if len(message) == 0 { + return "", "", gjson.Result{} + } + values := gjson.GetManyBytes(message, "type", "response.id", "id", "response") + eventType = strings.TrimSpace(values[0].String()) + if id := strings.TrimSpace(values[1].String()); id != "" { + responseID = id + } else { + responseID = strings.TrimSpace(values[2].String()) + } + return eventType, responseID, values[3] +} + +func openAIWSMessageLikelyContainsToolCalls(message []byte) bool { + if len(message) == 0 { + return false + } + return bytes.Contains(message, []byte(`"tool_calls"`)) || + bytes.Contains(message, 
[]byte(`"tool_call"`)) || + bytes.Contains(message, []byte(`"function_call"`)) +} + +func parseOpenAIWSResponseUsageFromCompletedEvent(message []byte, usage *OpenAIUsage) { + if usage == nil || len(message) == 0 { + return + } + values := gjson.GetManyBytes( + message, + "response.usage.input_tokens", + "response.usage.output_tokens", + "response.usage.input_tokens_details.cached_tokens", + ) + usage.InputTokens = int(values[0].Int()) + usage.OutputTokens = int(values[1].Int()) + usage.CacheReadInputTokens = int(values[2].Int()) +} + +func parseOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) { + if len(message) == 0 { + return "", "", "" + } + values := gjson.GetManyBytes(message, "error.code", "error.type", "error.message") + return strings.TrimSpace(values[0].String()), strings.TrimSpace(values[1].String()), strings.TrimSpace(values[2].String()) +} + +func summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMessageRaw string) (code string, errType string, errMessage string) { + code = truncateOpenAIWSLogValue(codeRaw, openAIWSLogValueMaxLen) + errType = truncateOpenAIWSLogValue(errTypeRaw, openAIWSLogValueMaxLen) + errMessage = truncateOpenAIWSLogValue(errMessageRaw, openAIWSLogValueMaxLen) + return code, errType, errMessage +} + +func summarizeOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) { + if len(message) == 0 { + return "-", "-", "-" + } + return summarizeOpenAIWSErrorEventFieldsFromRaw(parseOpenAIWSErrorEventFields(message)) +} + +func summarizeOpenAIWSPayloadKeySizes(payload map[string]any, topN int) string { + if len(payload) == 0 { + return "-" + } + type keySize struct { + Key string + Size int + } + sizes := make([]keySize, 0, len(payload)) + for key, value := range payload { + size := estimateOpenAIWSPayloadValueSize(value, openAIWSPayloadSizeEstimateDepth) + sizes = append(sizes, keySize{Key: key, Size: size}) + } + sort.Slice(sizes, func(i, j int) bool 
{ + if sizes[i].Size == sizes[j].Size { + return sizes[i].Key < sizes[j].Key + } + return sizes[i].Size > sizes[j].Size + }) + + if topN <= 0 || topN > len(sizes) { + topN = len(sizes) + } + parts := make([]string, 0, topN) + for idx := 0; idx < topN; idx++ { + item := sizes[idx] + parts = append(parts, fmt.Sprintf("%s:%d", item.Key, item.Size)) + } + return strings.Join(parts, ",") +} + +func estimateOpenAIWSPayloadValueSize(value any, depth int) int { + if depth <= 0 { + return -1 + } + switch v := value.(type) { + case nil: + return 0 + case string: + return len(v) + case []byte: + return len(v) + case bool: + return 1 + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return 8 + case float32, float64: + return 8 + case map[string]any: + if len(v) == 0 { + return 2 + } + total := 2 + count := 0 + for key, item := range v { + count++ + if count > openAIWSPayloadSizeEstimateMaxItems { + return -1 + } + itemSize := estimateOpenAIWSPayloadValueSize(item, depth-1) + if itemSize < 0 { + return -1 + } + total += len(key) + itemSize + 3 + if total > openAIWSPayloadSizeEstimateMaxBytes { + return -1 + } + } + return total + case []any: + if len(v) == 0 { + return 2 + } + total := 2 + limit := len(v) + if limit > openAIWSPayloadSizeEstimateMaxItems { + return -1 + } + for i := 0; i < limit; i++ { + itemSize := estimateOpenAIWSPayloadValueSize(v[i], depth-1) + if itemSize < 0 { + return -1 + } + total += itemSize + 1 + if total > openAIWSPayloadSizeEstimateMaxBytes { + return -1 + } + } + return total + default: + raw, err := json.Marshal(v) + if err != nil { + return -1 + } + if len(raw) > openAIWSPayloadSizeEstimateMaxBytes { + return -1 + } + return len(raw) + } +} + +func openAIWSPayloadString(payload map[string]any, key string) string { + if len(payload) == 0 { + return "" + } + raw, ok := payload[key] + if !ok { + return "" + } + switch v := raw.(type) { + case nil: + return "" + case string: + return strings.TrimSpace(v) + case []byte: + 
return strings.TrimSpace(string(v)) + default: + return "" + } +} + +func openAIWSPayloadStringFromRaw(payload []byte, key string) string { + if len(payload) == 0 || strings.TrimSpace(key) == "" { + return "" + } + return strings.TrimSpace(gjson.GetBytes(payload, key).String()) +} + +func openAIWSPayloadBoolFromRaw(payload []byte, key string, defaultValue bool) bool { + if len(payload) == 0 || strings.TrimSpace(key) == "" { + return defaultValue + } + value := gjson.GetBytes(payload, key) + if !value.Exists() { + return defaultValue + } + if value.Type != gjson.True && value.Type != gjson.False { + return defaultValue + } + return value.Bool() +} + +func openAIWSSessionHashesFromID(sessionID string) (string, string) { + return deriveOpenAISessionHashes(sessionID) +} + +func extractOpenAIWSImageURL(value any) string { + switch v := value.(type) { + case string: + return strings.TrimSpace(v) + case map[string]any: + if raw, ok := v["url"].(string); ok { + return strings.TrimSpace(raw) + } + } + return "" +} + +func summarizeOpenAIWSInput(input any) string { + items, ok := input.([]any) + if !ok || len(items) == 0 { + return "-" + } + + itemCount := len(items) + textChars := 0 + imageDataURLs := 0 + imageDataURLChars := 0 + imageRemoteURLs := 0 + + handleContentItem := func(contentItem map[string]any) { + contentType, _ := contentItem["type"].(string) + switch strings.TrimSpace(contentType) { + case "input_text", "output_text", "text": + if text, ok := contentItem["text"].(string); ok { + textChars += len(text) + } + case "input_image": + imageURL := extractOpenAIWSImageURL(contentItem["image_url"]) + if imageURL == "" { + return + } + if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") { + imageDataURLs++ + imageDataURLChars += len(imageURL) + return + } + imageRemoteURLs++ + } + } + + handleInputItem := func(inputItem map[string]any) { + if content, ok := inputItem["content"].([]any); ok { + for _, rawContent := range content { + contentItem, ok := 
rawContent.(map[string]any) + if !ok { + continue + } + handleContentItem(contentItem) + } + return + } + + itemType, _ := inputItem["type"].(string) + switch strings.TrimSpace(itemType) { + case "input_text", "output_text", "text": + if text, ok := inputItem["text"].(string); ok { + textChars += len(text) + } + case "input_image": + imageURL := extractOpenAIWSImageURL(inputItem["image_url"]) + if imageURL == "" { + return + } + if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") { + imageDataURLs++ + imageDataURLChars += len(imageURL) + return + } + imageRemoteURLs++ + } + } + + for _, rawItem := range items { + inputItem, ok := rawItem.(map[string]any) + if !ok { + continue + } + handleInputItem(inputItem) + } + + return fmt.Sprintf( + "items=%d,text_chars=%d,image_data_urls=%d,image_data_url_chars=%d,image_remote_urls=%d", + itemCount, + textChars, + imageDataURLs, + imageDataURLChars, + imageRemoteURLs, + ) +} + +func dropOpenAIWSPayloadKey(payload map[string]any, key string, removed *[]string) { + if len(payload) == 0 || strings.TrimSpace(key) == "" { + return + } + if _, exists := payload[key]; !exists { + return + } + delete(payload, key) + *removed = append(*removed, key) +} + +// applyOpenAIWSRetryPayloadStrategy 在 WS 连续失败时仅移除无语义字段, +// 避免重试成功却改变原始请求语义。 +// 注意:prompt_cache_key 不应在重试中移除;它常用于会话稳定标识(session_id 兜底)。 +func applyOpenAIWSRetryPayloadStrategy(payload map[string]any, attempt int) (strategy string, removedKeys []string) { + if len(payload) == 0 { + return "empty", nil + } + if attempt <= 1 { + return "full", nil + } + + removed := make([]string, 0, 2) + if attempt >= 2 { + dropOpenAIWSPayloadKey(payload, "include", &removed) + } + + if len(removed) == 0 { + return "full", nil + } + sort.Strings(removed) + return "trim_optional_fields", removed +} + +func logOpenAIWSModeInfo(format string, args ...any) { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode][openai_ws_mode=true] "+format, args...) 
+} + +func isOpenAIWSModeDebugEnabled() bool { + return logger.L().Core().Enabled(zap.DebugLevel) +} + +func logOpenAIWSModeDebug(format string, args ...any) { + if !isOpenAIWSModeDebugEnabled() { + return + } + logger.LegacyPrintf("service.openai_gateway", "[debug] [OpenAI WS Mode][openai_ws_mode=true] "+format, args...) +} + +func logOpenAIWSBindResponseAccountWarn(groupID, accountID int64, responseID string, err error) { + if err == nil { + return + } + logger.L().Warn( + "openai.ws_bind_response_account_failed", + zap.Int64("group_id", groupID), + zap.Int64("account_id", accountID), + zap.String("response_id", truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen)), + zap.Error(err), + ) +} + +func summarizeOpenAIWSReadCloseError(err error) (status string, reason string) { + if err == nil { + return "-", "-" + } + statusCode := coderws.CloseStatus(err) + if statusCode == -1 { + return "-", "-" + } + closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String()) + closeReason := "-" + var closeErr coderws.CloseError + if errors.As(err, &closeErr) { + reasonText := strings.TrimSpace(closeErr.Reason) + if reasonText != "" { + closeReason = normalizeOpenAIWSLogValue(reasonText) + } + } + return normalizeOpenAIWSLogValue(closeStatus), closeReason +} + +func unwrapOpenAIWSDialBaseError(err error) error { + if err == nil { + return nil + } + var dialErr *openAIWSDialError + if errors.As(err, &dialErr) && dialErr != nil && dialErr.Err != nil { + return dialErr.Err + } + return err +} + +func openAIWSDialRespHeaderForLog(err error, key string) string { + var dialErr *openAIWSDialError + if !errors.As(err, &dialErr) || dialErr == nil || dialErr.ResponseHeaders == nil { + return "-" + } + return truncateOpenAIWSLogValue(dialErr.ResponseHeaders.Get(key), openAIWSHeaderValueMaxLen) +} + +func classifyOpenAIWSDialError(err error) string { + if err == nil { + return "-" + } + baseErr := unwrapOpenAIWSDialBaseError(err) + if baseErr == nil { + return "-" + } + 
if errors.Is(baseErr, context.DeadlineExceeded) { + return "ctx_deadline_exceeded" + } + if errors.Is(baseErr, context.Canceled) { + return "ctx_canceled" + } + var netErr net.Error + if errors.As(baseErr, &netErr) && netErr.Timeout() { + return "net_timeout" + } + if status := coderws.CloseStatus(baseErr); status != -1 { + return normalizeOpenAIWSLogValue(fmt.Sprintf("ws_close_%d", int(status))) + } + message := strings.ToLower(strings.TrimSpace(baseErr.Error())) + switch { + case strings.Contains(message, "handshake not finished"): + return "handshake_not_finished" + case strings.Contains(message, "bad handshake"): + return "bad_handshake" + case strings.Contains(message, "connection refused"): + return "connection_refused" + case strings.Contains(message, "no such host"): + return "dns_not_found" + case strings.Contains(message, "tls"): + return "tls_error" + case strings.Contains(message, "i/o timeout"): + return "io_timeout" + case strings.Contains(message, "context deadline exceeded"): + return "ctx_deadline_exceeded" + default: + return "dial_error" + } +} + +func summarizeOpenAIWSDialError(err error) ( + statusCode int, + dialClass string, + closeStatus string, + closeReason string, + respServer string, + respVia string, + respCFRay string, + respRequestID string, +) { + dialClass = "-" + closeStatus = "-" + closeReason = "-" + respServer = "-" + respVia = "-" + respCFRay = "-" + respRequestID = "-" + if err == nil { + return + } + var dialErr *openAIWSDialError + if errors.As(err, &dialErr) && dialErr != nil { + statusCode = dialErr.StatusCode + respServer = openAIWSDialRespHeaderForLog(err, "server") + respVia = openAIWSDialRespHeaderForLog(err, "via") + respCFRay = openAIWSDialRespHeaderForLog(err, "cf-ray") + respRequestID = openAIWSDialRespHeaderForLog(err, "x-request-id") + } + dialClass = normalizeOpenAIWSLogValue(classifyOpenAIWSDialError(err)) + closeStatus, closeReason = summarizeOpenAIWSReadCloseError(unwrapOpenAIWSDialBaseError(err)) + return +} 
+
+// isOpenAIWSClientDisconnectError reports whether err represents a benign
+// client-side disconnect rather than an upstream failure: plain EOF, a closed
+// local connection, a canceled context, one of the "expected" WebSocket close
+// statuses, or a known disconnect message fragment in the error text.
+func isOpenAIWSClientDisconnectError(err error) bool {
+	if err == nil {
+		return false
+	}
+	if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) {
+		return true
+	}
+	// Close statuses that typically mean the peer went away on purpose
+	// (or dropped without a status), not that the request itself failed.
+	switch coderws.CloseStatus(err) {
+	case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure:
+		return true
+	}
+	// Fall back to string matching for transport-level errors that are not
+	// surfaced as typed errors by the underlying net/websocket stack.
+	message := strings.ToLower(strings.TrimSpace(err.Error()))
+	if message == "" {
+		return false
+	}
+	return strings.Contains(message, "failed to read frame header: eof") ||
+		strings.Contains(message, "unexpected eof") ||
+		strings.Contains(message, "use of closed network connection") ||
+		strings.Contains(message, "connection reset by peer") ||
+		strings.Contains(message, "broken pipe")
+}
+
+// classifyOpenAIWSReadFallbackReason maps a WS read error to a fallback-reason
+// label. Only policy-violation and message-too-big close statuses get their own
+// labels; everything else (including nil) collapses to the generic "read_event".
+func classifyOpenAIWSReadFallbackReason(err error) string {
+	if err == nil {
+		return "read_event"
+	}
+	switch coderws.CloseStatus(err) {
+	case coderws.StatusPolicyViolation:
+		return "policy_violation"
+	case coderws.StatusMessageTooBig:
+		return "message_too_big"
+	default:
+		return "read_event"
+	}
+}
+
+// sortedKeys returns the keys of m in ascending order, or nil when m is empty.
+// Used to produce deterministic key lists for logging.
+func sortedKeys(m map[string]any) []string {
+	if len(m) == 0 {
+		return nil
+	}
+	keys := make([]string, 0, len(m))
+	for k := range m {
+		keys = append(keys, k)
+	}
+	sort.Strings(keys)
+	return keys
+}
+
+// getOpenAIWSConnPool lazily initializes and returns the service's shared WS
+// connection pool. Safe on a nil receiver (returns nil); initialization runs
+// at most once via openaiWSPoolOnce and respects a pool injected beforehand
+// (e.g. by tests), only constructing a new one when s.openaiWSPool is unset.
+func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool {
+	if s == nil {
+		return nil
+	}
+	s.openaiWSPoolOnce.Do(func() {
+		if s.openaiWSPool == nil {
+			s.openaiWSPool = newOpenAIWSConnPool(s.cfg)
+		}
+	})
+	return s.openaiWSPool
+}
+
+// SnapshotOpenAIWSPoolMetrics returns a point-in-time snapshot of the WS
+// connection pool's metrics, or the zero snapshot when no pool is available.
+func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot {
+	pool := s.getOpenAIWSConnPool()
+	if pool == nil {
+		return OpenAIWSPoolMetricsSnapshot{}
+	}
+	return pool.SnapshotMetrics()
+}
+
+// OpenAIWSPerformanceMetricsSnapshot aggregates the pool, retry, and transport
+// metric snapshots exposed by the OpenAI WS gateway into one JSON-serializable
+// structure.
+type OpenAIWSPerformanceMetricsSnapshot struct {
+	Pool      OpenAIWSPoolMetricsSnapshot      `json:"pool"`
+	Retry     OpenAIWSRetryMetricsSnapshot     `json:"retry"`
+	Transport OpenAIWSTransportMetricsSnapshot `json:"transport"`
+}
+ +func (s *OpenAIGatewayService) SnapshotOpenAIWSPerformanceMetrics() OpenAIWSPerformanceMetricsSnapshot { + pool := s.getOpenAIWSConnPool() + snapshot := OpenAIWSPerformanceMetricsSnapshot{ + Retry: s.SnapshotOpenAIWSRetryMetrics(), + } + if pool == nil { + return snapshot + } + snapshot.Pool = pool.SnapshotMetrics() + snapshot.Transport = pool.SnapshotTransportMetrics() + return snapshot +} + +func (s *OpenAIGatewayService) getOpenAIWSStateStore() OpenAIWSStateStore { + if s == nil { + return nil + } + s.openaiWSStateStoreOnce.Do(func() { + if s.openaiWSStateStore == nil { + s.openaiWSStateStore = NewOpenAIWSStateStore(s.cache) + } + }) + return s.openaiWSStateStore +} + +func (s *OpenAIGatewayService) openAIWSResponseStickyTTL() time.Duration { + if s != nil && s.cfg != nil { + seconds := s.cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds + if seconds > 0 { + return time.Duration(seconds) * time.Second + } + } + return time.Hour +} + +func (s *OpenAIGatewayService) openAIWSIngressPreviousResponseRecoveryEnabled() bool { + if s != nil && s.cfg != nil { + return s.cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled + } + return true +} + +func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ReadTimeoutSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.ReadTimeoutSeconds) * time.Second + } + return 15 * time.Minute +} + +func (s *OpenAIGatewayService) openAIWSWriteTimeout() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second + } + return 2 * time.Minute +} + +func (s *OpenAIGatewayService) openAIWSEventFlushBatchSize() int { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EventFlushBatchSize > 0 { + return s.cfg.Gateway.OpenAIWS.EventFlushBatchSize + } + return openAIWSEventFlushBatchSizeDefault +} + +func (s *OpenAIGatewayService) 
openAIWSEventFlushInterval() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS >= 0 { + if s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS == 0 { + return 0 + } + return time.Duration(s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS) * time.Millisecond + } + return openAIWSEventFlushIntervalDefault +} + +func (s *OpenAIGatewayService) openAIWSPayloadLogSampleRate() float64 { + if s != nil && s.cfg != nil { + rate := s.cfg.Gateway.OpenAIWS.PayloadLogSampleRate + if rate < 0 { + return 0 + } + if rate > 1 { + return 1 + } + return rate + } + return openAIWSPayloadLogSampleDefault +} + +func (s *OpenAIGatewayService) shouldLogOpenAIWSPayloadSchema(attempt int) bool { + // 首次尝试保留一条完整 payload_schema 便于排障。 + if attempt <= 1 { + return true + } + rate := s.openAIWSPayloadLogSampleRate() + if rate <= 0 { + return false + } + if rate >= 1 { + return true + } + return rand.Float64() < rate +} + +func (s *OpenAIGatewayService) shouldEmitOpenAIWSPayloadSchema(attempt int) bool { + if !s.shouldLogOpenAIWSPayloadSchema(attempt) { + return false + } + return logger.L().Core().Enabled(zap.DebugLevel) +} + +func (s *OpenAIGatewayService) openAIWSDialTimeout() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.DialTimeoutSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.DialTimeoutSeconds) * time.Second + } + return 10 * time.Second +} + +func (s *OpenAIGatewayService) openAIWSAcquireTimeout() time.Duration { + // Acquire 覆盖“连接复用命中/排队/新建连接”三个阶段。 + // 这里不再叠加 write_timeout,避免高并发排队时把 TTFT 长尾拉到分钟级。 + dial := s.openAIWSDialTimeout() + if dial <= 0 { + dial = 10 * time.Second + } + return dial + 2*time.Second +} + +func (s *OpenAIGatewayService) buildOpenAIResponsesWSURL(account *Account) (string, error) { + if account == nil { + return "", errors.New("account is nil") + } + var targetURL string + switch account.Type { + case AccountTypeOAuth: + targetURL = chatgptCodexURL + case AccountTypeAPIKey: + baseURL := 
account.GetOpenAIBaseURL() + if baseURL == "" { + targetURL = openaiPlatformAPIURL + } else { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return "", err + } + targetURL = buildOpenAIResponsesURL(validatedURL) + } + default: + targetURL = openaiPlatformAPIURL + } + + parsed, err := url.Parse(strings.TrimSpace(targetURL)) + if err != nil { + return "", fmt.Errorf("invalid target url: %w", err) + } + switch strings.ToLower(parsed.Scheme) { + case "https": + parsed.Scheme = "wss" + case "http": + parsed.Scheme = "ws" + case "wss", "ws": + // 保持不变 + default: + return "", fmt.Errorf("unsupported scheme for ws: %s", parsed.Scheme) + } + return parsed.String(), nil +} + +func (s *OpenAIGatewayService) buildOpenAIWSHeaders( + c *gin.Context, + account *Account, + token string, + decision OpenAIWSProtocolDecision, + isCodexCLI bool, + turnState string, + turnMetadata string, + promptCacheKey string, +) (http.Header, openAIWSSessionHeaderResolution) { + headers := make(http.Header) + headers.Set("authorization", "Bearer "+token) + + sessionResolution := resolveOpenAIWSSessionHeaders(c, promptCacheKey) + if c != nil && c.Request != nil { + if v := strings.TrimSpace(c.Request.Header.Get("accept-language")); v != "" { + headers.Set("accept-language", v) + } + } + if sessionResolution.SessionID != "" { + headers.Set("session_id", sessionResolution.SessionID) + } + if sessionResolution.ConversationID != "" { + headers.Set("conversation_id", sessionResolution.ConversationID) + } + if state := strings.TrimSpace(turnState); state != "" { + headers.Set(openAIWSTurnStateHeader, state) + } + if metadata := strings.TrimSpace(turnMetadata); metadata != "" { + headers.Set(openAIWSTurnMetadataHeader, metadata) + } + + if account != nil && account.Type == AccountTypeOAuth { + if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { + headers.Set("chatgpt-account-id", chatgptAccountID) + } + if isCodexCLI { + headers.Set("originator", 
"codex_cli_rs") + } else { + headers.Set("originator", "opencode") + } + } + + betaValue := openAIWSBetaV2Value + if decision.Transport == OpenAIUpstreamTransportResponsesWebsocket { + betaValue = openAIWSBetaV1Value + } + headers.Set("OpenAI-Beta", betaValue) + + customUA := "" + if account != nil { + customUA = account.GetOpenAIUserAgent() + } + if strings.TrimSpace(customUA) != "" { + headers.Set("user-agent", customUA) + } else if c != nil { + if ua := strings.TrimSpace(c.GetHeader("User-Agent")); ua != "" { + headers.Set("user-agent", ua) + } + } + if s != nil && s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { + headers.Set("user-agent", codexCLIUserAgent) + } + if account != nil && account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(headers.Get("user-agent")) { + headers.Set("user-agent", codexCLIUserAgent) + } + + return headers, sessionResolution +} + +func (s *OpenAIGatewayService) buildOpenAIWSCreatePayload(reqBody map[string]any, account *Account) map[string]any { + // OpenAI WS Mode 协议:response.create 字段与 HTTP /responses 基本一致。 + // 保留 stream 字段(与 Codex CLI 一致),仅移除 background。 + payload := make(map[string]any, len(reqBody)+1) + for k, v := range reqBody { + payload[k] = v + } + + delete(payload, "background") + if _, exists := payload["stream"]; !exists { + payload["stream"] = true + } + payload["type"] = "response.create" + + // OAuth 默认保持 store=false,避免误依赖服务端历史。 + if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) { + payload["store"] = false + } + return payload +} + +func setOpenAIWSTurnMetadata(payload map[string]any, turnMetadata string) { + if len(payload) == 0 { + return + } + metadata := strings.TrimSpace(turnMetadata) + if metadata == "" { + return + } + + switch existing := payload["client_metadata"].(type) { + case map[string]any: + existing[openAIWSTurnMetadataHeader] = metadata + payload["client_metadata"] = existing + case map[string]string: + next := make(map[string]any, 
len(existing)+1) + for k, v := range existing { + next[k] = v + } + next[openAIWSTurnMetadataHeader] = metadata + payload["client_metadata"] = next + default: + payload["client_metadata"] = map[string]any{ + openAIWSTurnMetadataHeader: metadata, + } + } +} + +func (s *OpenAIGatewayService) isOpenAIWSStoreRecoveryAllowed(account *Account) bool { + if account != nil && account.IsOpenAIWSAllowStoreRecoveryEnabled() { + return true + } + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.AllowStoreRecovery { + return true + } + return false +} + +func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequest(reqBody map[string]any, account *Account) bool { + if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) { + return true + } + if len(reqBody) == 0 { + return false + } + rawStore, ok := reqBody["store"] + if !ok { + return false + } + storeEnabled, ok := rawStore.(bool) + if !ok { + return false + } + return !storeEnabled +} + +func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequestRaw(reqBody []byte, account *Account) bool { + if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) { + return true + } + if len(reqBody) == 0 { + return false + } + storeValue := gjson.GetBytes(reqBody, "store") + if !storeValue.Exists() { + return false + } + if storeValue.Type != gjson.True && storeValue.Type != gjson.False { + return false + } + return !storeValue.Bool() +} + +func (s *OpenAIGatewayService) openAIWSStoreDisabledConnMode() string { + if s == nil || s.cfg == nil { + return openAIWSStoreDisabledConnModeStrict + } + mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.OpenAIWS.StoreDisabledConnMode)) + switch mode { + case openAIWSStoreDisabledConnModeStrict, openAIWSStoreDisabledConnModeAdaptive, openAIWSStoreDisabledConnModeOff: + return mode + case "": + // 兼容旧配置:仅配置了布尔开关时按旧语义推导。 + if s.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn { + return 
openAIWSStoreDisabledConnModeStrict + } + return openAIWSStoreDisabledConnModeOff + default: + return openAIWSStoreDisabledConnModeStrict + } +} + +func shouldForceNewConnOnStoreDisabled(mode, lastFailureReason string) bool { + switch mode { + case openAIWSStoreDisabledConnModeOff: + return false + case openAIWSStoreDisabledConnModeAdaptive: + reason := strings.TrimPrefix(strings.TrimSpace(lastFailureReason), "prewarm_") + switch reason { + case "policy_violation", "message_too_big", "auth_failed", "write_request", "write": + return true + default: + return false + } + default: + return true + } +} + +func dropPreviousResponseIDFromRawPayload(payload []byte) ([]byte, bool, error) { + return dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, sjson.DeleteBytes) +} + +func dropPreviousResponseIDFromRawPayloadWithDeleteFn( + payload []byte, + deleteFn func([]byte, string) ([]byte, error), +) ([]byte, bool, error) { + if len(payload) == 0 { + return payload, false, nil + } + if !gjson.GetBytes(payload, "previous_response_id").Exists() { + return payload, false, nil + } + if deleteFn == nil { + deleteFn = sjson.DeleteBytes + } + + updated := payload + for i := 0; i < openAIWSMaxPrevResponseIDDeletePasses && + gjson.GetBytes(updated, "previous_response_id").Exists(); i++ { + next, err := deleteFn(updated, "previous_response_id") + if err != nil { + return payload, false, err + } + updated = next + } + return updated, !gjson.GetBytes(updated, "previous_response_id").Exists(), nil +} + +func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string) ([]byte, error) { + normalizedPrevID := strings.TrimSpace(previousResponseID) + if len(payload) == 0 || normalizedPrevID == "" { + return payload, nil + } + updated, err := sjson.SetBytes(payload, "previous_response_id", normalizedPrevID) + if err == nil { + return updated, nil + } + + var reqBody map[string]any + if unmarshalErr := json.Unmarshal(payload, &reqBody); unmarshalErr != nil { + return nil, 
err + } + reqBody["previous_response_id"] = normalizedPrevID + rebuilt, marshalErr := json.Marshal(reqBody) + if marshalErr != nil { + return nil, marshalErr + } + return rebuilt, nil +} + +func shouldInferIngressFunctionCallOutputPreviousResponseID( + storeDisabled bool, + turn int, + hasFunctionCallOutput bool, + currentPreviousResponseID string, + expectedPreviousResponseID string, +) bool { + if !storeDisabled || turn <= 1 || !hasFunctionCallOutput { + return false + } + if strings.TrimSpace(currentPreviousResponseID) != "" { + return false + } + return strings.TrimSpace(expectedPreviousResponseID) != "" +} + +func alignStoreDisabledPreviousResponseID( + payload []byte, + expectedPreviousResponseID string, +) ([]byte, bool, error) { + if len(payload) == 0 { + return payload, false, nil + } + expected := strings.TrimSpace(expectedPreviousResponseID) + if expected == "" { + return payload, false, nil + } + current := openAIWSPayloadStringFromRaw(payload, "previous_response_id") + if current == "" || current == expected { + return payload, false, nil + } + + withoutPrev, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload) + if dropErr != nil { + return payload, false, dropErr + } + if !removed { + return payload, false, nil + } + updated, setErr := setPreviousResponseIDToRawPayload(withoutPrev, expected) + if setErr != nil { + return payload, false, setErr + } + return updated, true, nil +} + +func cloneOpenAIWSPayloadBytes(payload []byte) []byte { + if len(payload) == 0 { + return nil + } + cloned := make([]byte, len(payload)) + copy(cloned, payload) + return cloned +} + +func cloneOpenAIWSRawMessages(items []json.RawMessage) []json.RawMessage { + if items == nil { + return nil + } + cloned := make([]json.RawMessage, 0, len(items)) + for idx := range items { + cloned = append(cloned, json.RawMessage(cloneOpenAIWSPayloadBytes(items[idx]))) + } + return cloned +} + +func normalizeOpenAIWSJSONForCompare(raw []byte) ([]byte, error) { + trimmed := 
bytes.TrimSpace(raw) + if len(trimmed) == 0 { + return nil, errors.New("json is empty") + } + var decoded any + if err := json.Unmarshal(trimmed, &decoded); err != nil { + return nil, err + } + return json.Marshal(decoded) +} + +func normalizeOpenAIWSJSONForCompareOrRaw(raw []byte) []byte { + normalized, err := normalizeOpenAIWSJSONForCompare(raw) + if err != nil { + return bytes.TrimSpace(raw) + } + return normalized +} + +func normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload []byte) ([]byte, error) { + if len(payload) == 0 { + return nil, errors.New("payload is empty") + } + var decoded map[string]any + if err := json.Unmarshal(payload, &decoded); err != nil { + return nil, err + } + delete(decoded, "input") + delete(decoded, "previous_response_id") + return json.Marshal(decoded) +} + +func openAIWSExtractNormalizedInputSequence(payload []byte) ([]json.RawMessage, bool, error) { + if len(payload) == 0 { + return nil, false, nil + } + inputValue := gjson.GetBytes(payload, "input") + if !inputValue.Exists() { + return nil, false, nil + } + if inputValue.Type == gjson.JSON { + raw := strings.TrimSpace(inputValue.Raw) + if strings.HasPrefix(raw, "[") { + var items []json.RawMessage + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return nil, true, err + } + return items, true, nil + } + return []json.RawMessage{json.RawMessage(raw)}, true, nil + } + if inputValue.Type == gjson.String { + encoded, _ := json.Marshal(inputValue.String()) + return []json.RawMessage{encoded}, true, nil + } + return []json.RawMessage{json.RawMessage(inputValue.Raw)}, true, nil +} + +func openAIWSInputIsPrefixExtended(previousPayload, currentPayload []byte) (bool, error) { + previousItems, previousExists, prevErr := openAIWSExtractNormalizedInputSequence(previousPayload) + if prevErr != nil { + return false, prevErr + } + currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload) + if currentErr != nil { + return false, 
currentErr + } + if !previousExists && !currentExists { + return true, nil + } + if !previousExists { + return len(currentItems) == 0, nil + } + if !currentExists { + return len(previousItems) == 0, nil + } + if len(currentItems) < len(previousItems) { + return false, nil + } + + for idx := range previousItems { + previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(previousItems[idx]) + currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(currentItems[idx]) + if !bytes.Equal(previousNormalized, currentNormalized) { + return false, nil + } + } + return true, nil +} + +func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage) bool { + if len(prefix) == 0 { + return true + } + if len(items) < len(prefix) { + return false + } + for idx := range prefix { + previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(prefix[idx]) + currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(items[idx]) + if !bytes.Equal(previousNormalized, currentNormalized) { + return false + } + } + return true +} + +func buildOpenAIWSReplayInputSequence( + previousFullInput []json.RawMessage, + previousFullInputExists bool, + currentPayload []byte, + hasPreviousResponseID bool, +) ([]json.RawMessage, bool, error) { + currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload) + if currentErr != nil { + return nil, false, currentErr + } + if !hasPreviousResponseID { + return cloneOpenAIWSRawMessages(currentItems), currentExists, nil + } + if !previousFullInputExists { + return cloneOpenAIWSRawMessages(currentItems), currentExists, nil + } + if !currentExists || len(currentItems) == 0 { + return cloneOpenAIWSRawMessages(previousFullInput), true, nil + } + if openAIWSRawItemsHasPrefix(currentItems, previousFullInput) { + return cloneOpenAIWSRawMessages(currentItems), true, nil + } + merged := make([]json.RawMessage, 0, len(previousFullInput)+len(currentItems)) + merged = append(merged, 
cloneOpenAIWSRawMessages(previousFullInput)...) + merged = append(merged, cloneOpenAIWSRawMessages(currentItems)...) + return merged, true, nil +} + +func setOpenAIWSPayloadInputSequence( + payload []byte, + fullInput []json.RawMessage, + fullInputExists bool, +) ([]byte, error) { + if !fullInputExists { + return payload, nil + } + // Preserve [] vs null semantics when input exists but is empty. + inputForMarshal := fullInput + if inputForMarshal == nil { + inputForMarshal = []json.RawMessage{} + } + inputRaw, marshalErr := json.Marshal(inputForMarshal) + if marshalErr != nil { + return nil, marshalErr + } + return sjson.SetRawBytes(payload, "input", inputRaw) +} + +func shouldKeepIngressPreviousResponseID( + previousPayload []byte, + currentPayload []byte, + lastTurnResponseID string, + hasFunctionCallOutput bool, +) (bool, string, error) { + if hasFunctionCallOutput { + return true, "has_function_call_output", nil + } + currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")) + if currentPreviousResponseID == "" { + return false, "missing_previous_response_id", nil + } + expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID) + if expectedPreviousResponseID == "" { + return false, "missing_last_turn_response_id", nil + } + if currentPreviousResponseID != expectedPreviousResponseID { + return false, "previous_response_id_mismatch", nil + } + if len(previousPayload) == 0 { + return false, "missing_previous_turn_payload", nil + } + + previousComparable, previousComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(previousPayload) + if previousComparableErr != nil { + return false, "non_input_compare_error", previousComparableErr + } + currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload) + if currentComparableErr != nil { + return false, "non_input_compare_error", currentComparableErr + } + if 
!bytes.Equal(previousComparable, currentComparable) { + return false, "non_input_changed", nil + } + return true, "strict_incremental_ok", nil +} + +type openAIWSIngressPreviousTurnStrictState struct { + nonInputComparable []byte +} + +func buildOpenAIWSIngressPreviousTurnStrictState(payload []byte) (*openAIWSIngressPreviousTurnStrictState, error) { + if len(payload) == 0 { + return nil, nil + } + nonInputComparable, nonInputErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload) + if nonInputErr != nil { + return nil, nonInputErr + } + return &openAIWSIngressPreviousTurnStrictState{ + nonInputComparable: nonInputComparable, + }, nil +} + +func shouldKeepIngressPreviousResponseIDWithStrictState( + previousState *openAIWSIngressPreviousTurnStrictState, + currentPayload []byte, + lastTurnResponseID string, + hasFunctionCallOutput bool, +) (bool, string, error) { + if hasFunctionCallOutput { + return true, "has_function_call_output", nil + } + currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")) + if currentPreviousResponseID == "" { + return false, "missing_previous_response_id", nil + } + expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID) + if expectedPreviousResponseID == "" { + return false, "missing_last_turn_response_id", nil + } + if currentPreviousResponseID != expectedPreviousResponseID { + return false, "previous_response_id_mismatch", nil + } + if previousState == nil { + return false, "missing_previous_turn_payload", nil + } + + currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload) + if currentComparableErr != nil { + return false, "non_input_compare_error", currentComparableErr + } + if !bytes.Equal(previousState.nonInputComparable, currentComparable) { + return false, "non_input_changed", nil + } + return true, "strict_incremental_ok", nil +} + +func (s *OpenAIGatewayService) forwardOpenAIWSV2( 
+ ctx context.Context, + c *gin.Context, + account *Account, + reqBody map[string]any, + token string, + decision OpenAIWSProtocolDecision, + isCodexCLI bool, + reqStream bool, + originalModel string, + mappedModel string, + startTime time.Time, + attempt int, + lastFailureReason string, +) (*OpenAIForwardResult, error) { + if s == nil || account == nil { + return nil, wrapOpenAIWSFallback("invalid_state", errors.New("service or account is nil")) + } + + wsURL, err := s.buildOpenAIResponsesWSURL(account) + if err != nil { + return nil, wrapOpenAIWSFallback("build_ws_url", err) + } + wsHost := "-" + wsPath := "-" + if parsed, parseErr := url.Parse(wsURL); parseErr == nil && parsed != nil { + if h := strings.TrimSpace(parsed.Host); h != "" { + wsHost = normalizeOpenAIWSLogValue(h) + } + if p := strings.TrimSpace(parsed.Path); p != "" { + wsPath = normalizeOpenAIWSLogValue(p) + } + } + logOpenAIWSModeDebug( + "dial_target account_id=%d account_type=%s ws_host=%s ws_path=%s", + account.ID, + account.Type, + wsHost, + wsPath, + ) + + payload := s.buildOpenAIWSCreatePayload(reqBody, account) + payloadStrategy, removedKeys := applyOpenAIWSRetryPayloadStrategy(payload, attempt) + previousResponseID := openAIWSPayloadString(payload, "previous_response_id") + previousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(previousResponseID) + promptCacheKey := openAIWSPayloadString(payload, "prompt_cache_key") + _, hasTools := payload["tools"] + debugEnabled := isOpenAIWSModeDebugEnabled() + payloadBytes := -1 + resolvePayloadBytes := func() int { + if payloadBytes >= 0 { + return payloadBytes + } + payloadBytes = len(payloadAsJSONBytes(payload)) + return payloadBytes + } + streamValue := "-" + if raw, ok := payload["stream"]; ok { + streamValue = normalizeOpenAIWSLogValue(strings.TrimSpace(fmt.Sprintf("%v", raw))) + } + turnState := "" + turnMetadata := "" + if c != nil && c.Request != nil { + turnState = strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader)) + 
turnMetadata = strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)) + } + setOpenAIWSTurnMetadata(payload, turnMetadata) + payloadEventType := openAIWSPayloadString(payload, "type") + if payloadEventType == "" { + payloadEventType = "response.create" + } + if s.shouldEmitOpenAIWSPayloadSchema(attempt) { + logOpenAIWSModeInfo( + "[debug] payload_schema account_id=%d attempt=%d event=%s payload_keys=%s payload_bytes=%d payload_key_sizes=%s input_summary=%s stream=%s payload_strategy=%s removed_keys=%s has_previous_response_id=%v has_prompt_cache_key=%v has_tools=%v", + account.ID, + attempt, + payloadEventType, + normalizeOpenAIWSLogValue(strings.Join(sortedKeys(payload), ",")), + resolvePayloadBytes(), + normalizeOpenAIWSLogValue(summarizeOpenAIWSPayloadKeySizes(payload, openAIWSPayloadKeySizeTopN)), + normalizeOpenAIWSLogValue(summarizeOpenAIWSInput(payload["input"])), + streamValue, + normalizeOpenAIWSLogValue(payloadStrategy), + normalizeOpenAIWSLogValue(strings.Join(removedKeys, ",")), + previousResponseID != "", + promptCacheKey != "", + hasTools, + ) + } + + stateStore := s.getOpenAIWSStateStore() + groupID := getOpenAIGroupIDFromContext(c) + sessionHash := s.GenerateSessionHash(c, nil) + if sessionHash == "" { + var legacySessionHash string + sessionHash, legacySessionHash = openAIWSSessionHashesFromID(promptCacheKey) + attachOpenAILegacySessionHashToGin(c, legacySessionHash) + } + if turnState == "" && stateStore != nil && sessionHash != "" { + if savedTurnState, ok := stateStore.GetSessionTurnState(groupID, sessionHash); ok { + turnState = savedTurnState + } + } + preferredConnID := "" + if stateStore != nil && previousResponseID != "" { + if connID, ok := stateStore.GetResponseConn(previousResponseID); ok { + preferredConnID = connID + } + } + storeDisabled := s.isOpenAIWSStoreDisabledInRequest(reqBody, account) + if stateStore != nil && storeDisabled && previousResponseID == "" && sessionHash != "" { + if connID, ok := 
stateStore.GetSessionConn(groupID, sessionHash); ok { + preferredConnID = connID + } + } + storeDisabledConnMode := s.openAIWSStoreDisabledConnMode() + forceNewConnByPolicy := shouldForceNewConnOnStoreDisabled(storeDisabledConnMode, lastFailureReason) + forceNewConn := forceNewConnByPolicy && storeDisabled && previousResponseID == "" && sessionHash != "" && preferredConnID == "" + wsHeaders, sessionResolution := s.buildOpenAIWSHeaders(c, account, token, decision, isCodexCLI, turnState, turnMetadata, promptCacheKey) + logOpenAIWSModeDebug( + "acquire_start account_id=%d account_type=%s transport=%s preferred_conn_id=%s has_previous_response_id=%v session_hash=%s has_turn_state=%v turn_state_len=%d has_turn_metadata=%v turn_metadata_len=%d store_disabled=%v store_disabled_conn_mode=%s retry_last_reason=%s force_new_conn=%v header_user_agent=%s header_openai_beta=%s header_originator=%s header_accept_language=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_prompt_cache_key=%v has_chatgpt_account_id=%v has_authorization=%v has_session_id=%v has_conversation_id=%v proxy_enabled=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(decision.Transport)), + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + previousResponseID != "", + truncateOpenAIWSLogValue(sessionHash, 12), + turnState != "", + len(turnState), + turnMetadata != "", + len(turnMetadata), + storeDisabled, + normalizeOpenAIWSLogValue(storeDisabledConnMode), + truncateOpenAIWSLogValue(lastFailureReason, openAIWSLogValueMaxLen), + forceNewConn, + openAIWSHeaderValueForLog(wsHeaders, "user-agent"), + openAIWSHeaderValueForLog(wsHeaders, "openai-beta"), + openAIWSHeaderValueForLog(wsHeaders, "originator"), + openAIWSHeaderValueForLog(wsHeaders, "accept-language"), + openAIWSHeaderValueForLog(wsHeaders, "session_id"), + openAIWSHeaderValueForLog(wsHeaders, "conversation_id"), + 
normalizeOpenAIWSLogValue(sessionResolution.SessionSource), + normalizeOpenAIWSLogValue(sessionResolution.ConversationSource), + promptCacheKey != "", + hasOpenAIWSHeader(wsHeaders, "chatgpt-account-id"), + hasOpenAIWSHeader(wsHeaders, "authorization"), + hasOpenAIWSHeader(wsHeaders, "session_id"), + hasOpenAIWSHeader(wsHeaders, "conversation_id"), + account.ProxyID != nil && account.Proxy != nil, + ) + + acquireCtx, acquireCancel := context.WithTimeout(ctx, s.openAIWSAcquireTimeout()) + defer acquireCancel() + + lease, err := s.getOpenAIWSConnPool().Acquire(acquireCtx, openAIWSAcquireRequest{ + Account: account, + WSURL: wsURL, + Headers: wsHeaders, + PreferredConnID: preferredConnID, + ForceNewConn: forceNewConn, + ProxyURL: func() string { + if account.ProxyID != nil && account.Proxy != nil { + return account.Proxy.URL() + } + return "" + }(), + }) + if err != nil { + dialStatus, dialClass, dialCloseStatus, dialCloseReason, dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID := summarizeOpenAIWSDialError(err) + logOpenAIWSModeInfo( + "acquire_fail account_id=%d account_type=%s transport=%s reason=%s dial_status=%d dial_class=%s dial_close_status=%s dial_close_reason=%s dial_resp_server=%s dial_resp_via=%s dial_resp_cf_ray=%s dial_resp_x_request_id=%s cause=%s preferred_conn_id=%s force_new_conn=%v ws_host=%s ws_path=%s proxy_enabled=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(decision.Transport)), + normalizeOpenAIWSLogValue(classifyOpenAIWSAcquireError(err)), + dialStatus, + dialClass, + dialCloseStatus, + truncateOpenAIWSLogValue(dialCloseReason, openAIWSHeaderValueMaxLen), + dialRespServer, + dialRespVia, + dialRespCFRay, + dialRespReqID, + truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + forceNewConn, + wsHost, + wsPath, + account.ProxyID != nil && account.Proxy != nil, + ) + return nil, wrapOpenAIWSFallback(classifyOpenAIWSAcquireError(err), 
err) + } + defer lease.Release() + connID := strings.TrimSpace(lease.ConnID()) + logOpenAIWSModeDebug( + "connected account_id=%d account_type=%s transport=%s conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d has_previous_response_id=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(decision.Transport)), + connID, + lease.Reused(), + lease.ConnPickDuration().Milliseconds(), + lease.QueueWaitDuration().Milliseconds(), + previousResponseID != "", + ) + if previousResponseID != "" { + logOpenAIWSModeInfo( + "continuation_probe account_id=%d account_type=%s conn_id=%s previous_response_id=%s previous_response_id_kind=%s preferred_conn_id=%s conn_reused=%v store_disabled=%v session_hash=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v", + account.ID, + account.Type, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(previousResponseIDKind), + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + lease.Reused(), + storeDisabled, + truncateOpenAIWSLogValue(sessionHash, 12), + openAIWSHeaderValueForLog(wsHeaders, "session_id"), + openAIWSHeaderValueForLog(wsHeaders, "conversation_id"), + normalizeOpenAIWSLogValue(sessionResolution.SessionSource), + normalizeOpenAIWSLogValue(sessionResolution.ConversationSource), + turnState != "", + len(turnState), + promptCacheKey != "", + ) + } + if c != nil { + SetOpsLatencyMs(c, OpsOpenAIWSConnPickMsKey, lease.ConnPickDuration().Milliseconds()) + SetOpsLatencyMs(c, OpsOpenAIWSQueueWaitMsKey, lease.QueueWaitDuration().Milliseconds()) + c.Set(OpsOpenAIWSConnReusedKey, lease.Reused()) + if connID != "" { + c.Set(OpsOpenAIWSConnIDKey, connID) + } + } + + handshakeTurnState := strings.TrimSpace(lease.HandshakeHeader(openAIWSTurnStateHeader)) + logOpenAIWSModeDebug( + "handshake 
account_id=%d conn_id=%s has_turn_state=%v turn_state_len=%d", + account.ID, + connID, + handshakeTurnState != "", + len(handshakeTurnState), + ) + if handshakeTurnState != "" { + if stateStore != nil && sessionHash != "" { + stateStore.BindSessionTurnState(groupID, sessionHash, handshakeTurnState, s.openAIWSSessionStickyTTL()) + } + if c != nil { + c.Header(http.CanonicalHeaderKey(openAIWSTurnStateHeader), handshakeTurnState) + } + } + + if err := s.performOpenAIWSGeneratePrewarm( + ctx, + lease, + decision, + payload, + previousResponseID, + reqBody, + account, + stateStore, + groupID, + ); err != nil { + return nil, err + } + + if err := lease.WriteJSONWithContextTimeout(ctx, payload, s.openAIWSWriteTimeout()); err != nil { + lease.MarkBroken() + logOpenAIWSModeInfo( + "write_request_fail account_id=%d conn_id=%s cause=%s payload_bytes=%d", + account.ID, + connID, + truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + resolvePayloadBytes(), + ) + return nil, wrapOpenAIWSFallback("write_request", err) + } + if debugEnabled { + logOpenAIWSModeDebug( + "write_request_sent account_id=%d conn_id=%s stream=%v payload_bytes=%d previous_response_id=%s", + account.ID, + connID, + reqStream, + resolvePayloadBytes(), + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + ) + } + + usage := &OpenAIUsage{} + var firstTokenMs *int + responseID := "" + var finalResponse []byte + wroteDownstream := false + needModelReplace := originalModel != mappedModel + var mappedModelBytes []byte + if needModelReplace && mappedModel != "" { + mappedModelBytes = []byte(mappedModel) + } + bufferedStreamEvents := make([][]byte, 0, 4) + eventCount := 0 + tokenEventCount := 0 + terminalEventCount := 0 + bufferedEventCount := 0 + flushedBufferedEventCount := 0 + firstEventType := "" + lastEventType := "" + + var flusher http.Flusher + if reqStream { + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), http.Header{}, 
s.responseHeaderFilter) + } + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + f, ok := c.Writer.(http.Flusher) + if !ok { + lease.MarkBroken() + return nil, wrapOpenAIWSFallback("streaming_not_supported", errors.New("streaming not supported")) + } + flusher = f + } + + clientDisconnected := false + flushBatchSize := s.openAIWSEventFlushBatchSize() + flushInterval := s.openAIWSEventFlushInterval() + pendingFlushEvents := 0 + lastFlushAt := time.Now() + flushStreamWriter := func(force bool) { + if clientDisconnected || flusher == nil || pendingFlushEvents <= 0 { + return + } + if !force && flushBatchSize > 1 && pendingFlushEvents < flushBatchSize { + if flushInterval <= 0 || time.Since(lastFlushAt) < flushInterval { + return + } + } + flusher.Flush() + pendingFlushEvents = 0 + lastFlushAt = time.Now() + } + emitStreamMessage := func(message []byte, forceFlush bool) { + if clientDisconnected { + return + } + frame := make([]byte, 0, len(message)+8) + frame = append(frame, "data: "...) + frame = append(frame, message...) 
+ frame = append(frame, '\n', '\n') + _, wErr := c.Writer.Write(frame) + if wErr == nil { + wroteDownstream = true + pendingFlushEvents++ + flushStreamWriter(forceFlush) + return + } + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode] client disconnected, continue draining upstream: account=%d", account.ID) + } + flushBufferedStreamEvents := func(reason string) { + if len(bufferedStreamEvents) == 0 { + return + } + flushed := len(bufferedStreamEvents) + for _, buffered := range bufferedStreamEvents { + emitStreamMessage(buffered, false) + } + bufferedStreamEvents = bufferedStreamEvents[:0] + flushStreamWriter(true) + flushedBufferedEventCount += flushed + if debugEnabled { + logOpenAIWSModeDebug( + "buffer_flush account_id=%d conn_id=%s reason=%s flushed=%d total_flushed=%d client_disconnected=%v", + account.ID, + connID, + truncateOpenAIWSLogValue(reason, openAIWSLogValueMaxLen), + flushed, + flushedBufferedEventCount, + clientDisconnected, + ) + } + } + + readTimeout := s.openAIWSReadTimeout() + + for { + message, readErr := lease.ReadMessageWithContextTimeout(ctx, readTimeout) + if readErr != nil { + lease.MarkBroken() + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) + logOpenAIWSModeInfo( + "read_fail account_id=%d conn_id=%s wrote_downstream=%v close_status=%s close_reason=%s cause=%s events=%d token_events=%d terminal_events=%d buffered_pending=%d buffered_flushed=%d first_event=%s last_event=%s", + account.ID, + connID, + wroteDownstream, + closeStatus, + closeReason, + truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen), + eventCount, + tokenEventCount, + terminalEventCount, + len(bufferedStreamEvents), + flushedBufferedEventCount, + truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen), + ) + if !wroteDownstream { + return nil, wrapOpenAIWSFallback(classifyOpenAIWSReadFallbackReason(readErr), readErr) + 
} + if clientDisconnected { + break + } + setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(readErr.Error()), "") + return nil, fmt.Errorf("openai ws read event: %w", readErr) + } + + eventType, eventResponseID, responseField := parseOpenAIWSEventEnvelope(message) + if eventType == "" { + continue + } + eventCount++ + if firstEventType == "" { + firstEventType = eventType + } + lastEventType = eventType + + if responseID == "" && eventResponseID != "" { + responseID = eventResponseID + } + + isTokenEvent := isOpenAIWSTokenEvent(eventType) + if isTokenEvent { + tokenEventCount++ + } + isTerminalEvent := isOpenAIWSTerminalEvent(eventType) + if isTerminalEvent { + terminalEventCount++ + } + if firstTokenMs == nil && isTokenEvent { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + if debugEnabled && shouldLogOpenAIWSEvent(eventCount, eventType) { + logOpenAIWSModeDebug( + "event_received account_id=%d conn_id=%s idx=%d type=%s bytes=%d token=%v terminal=%v buffered_pending=%d", + account.ID, + connID, + eventCount, + truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen), + len(message), + isTokenEvent, + isTerminalEvent, + len(bufferedStreamEvents), + ) + } + + if !clientDisconnected { + if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(message, mappedModelBytes) { + message = replaceOpenAIWSMessageModel(message, mappedModel, originalModel) + } + if openAIWSEventMayContainToolCalls(eventType) && openAIWSMessageLikelyContainsToolCalls(message) { + if corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(message); changed { + message = corrected + } + } + } + if openAIWSEventShouldParseUsage(eventType) { + parseOpenAIWSResponseUsageFromCompletedEvent(message, usage) + } + + if eventType == "error" { + errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) + errMsg := strings.TrimSpace(errMsgRaw) + if errMsg == "" { + errMsg = "Upstream 
websocket error" + } + fallbackReason, canFallback := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + logOpenAIWSModeInfo( + "error_event account_id=%d conn_id=%s idx=%d fallback_reason=%s can_fallback=%v err_code=%s err_type=%s err_message=%s", + account.ID, + connID, + eventCount, + truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), + canFallback, + errCode, + errType, + errMessage, + ) + if fallbackReason == "previous_response_not_found" { + logOpenAIWSModeInfo( + "previous_response_not_found_diag account_id=%d account_type=%s conn_id=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s event_idx=%d req_stream=%v store_disabled=%v conn_reused=%v session_hash=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v err_code=%s err_type=%s err_message=%s", + account.ID, + account.Type, + connID, + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(previousResponseIDKind), + truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + eventCount, + reqStream, + storeDisabled, + lease.Reused(), + truncateOpenAIWSLogValue(sessionHash, 12), + openAIWSHeaderValueForLog(wsHeaders, "session_id"), + openAIWSHeaderValueForLog(wsHeaders, "conversation_id"), + normalizeOpenAIWSLogValue(sessionResolution.SessionSource), + normalizeOpenAIWSLogValue(sessionResolution.ConversationSource), + turnState != "", + len(turnState), + promptCacheKey != "", + errCode, + errType, + errMessage, + ) + } + // error 事件后连接不再可复用,避免回池后污染下一请求。 + lease.MarkBroken() + if !wroteDownstream && canFallback { + return nil, wrapOpenAIWSFallback(fallbackReason, errors.New(errMsg)) + } + statusCode := openAIWSErrorHTTPStatusFromRaw(errCodeRaw, errTypeRaw) + setOpsUpstreamError(c, statusCode, 
errMsg, "") + if reqStream && !clientDisconnected { + flushBufferedStreamEvents("error_event") + emitStreamMessage(message, true) + } + if !reqStream { + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": errMsg, + }, + }) + } + return nil, fmt.Errorf("openai ws error event: %s", errMsg) + } + + if reqStream { + // 在首个 token 前先缓冲事件(如 response.created), + // 以便上游早期断连时仍可安全回退到 HTTP,不给下游发送半截流。 + shouldBuffer := firstTokenMs == nil && !isTokenEvent && !isTerminalEvent + if shouldBuffer { + buffered := make([]byte, len(message)) + copy(buffered, message) + bufferedStreamEvents = append(bufferedStreamEvents, buffered) + bufferedEventCount++ + if debugEnabled && shouldLogOpenAIWSBufferedEvent(bufferedEventCount) { + logOpenAIWSModeDebug( + "buffer_enqueue account_id=%d conn_id=%s idx=%d event_idx=%d event_type=%s buffer_size=%d", + account.ID, + connID, + bufferedEventCount, + eventCount, + truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen), + len(bufferedStreamEvents), + ) + } + } else { + flushBufferedStreamEvents(eventType) + emitStreamMessage(message, isTerminalEvent) + } + } else { + if responseField.Exists() && responseField.Type == gjson.JSON { + finalResponse = []byte(responseField.Raw) + } + } + + if isTerminalEvent { + break + } + } + + if !reqStream { + if len(finalResponse) == 0 { + logOpenAIWSModeInfo( + "missing_final_response account_id=%d conn_id=%s events=%d token_events=%d terminal_events=%d wrote_downstream=%v", + account.ID, + connID, + eventCount, + tokenEventCount, + terminalEventCount, + wroteDownstream, + ) + if !wroteDownstream { + return nil, wrapOpenAIWSFallback("missing_final_response", errors.New("no terminal response payload")) + } + return nil, errors.New("ws finished without final response") + } + + if needModelReplace { + finalResponse = s.replaceModelInResponseBody(finalResponse, mappedModel, originalModel) + } + finalResponse = s.correctToolCallsInResponseBody(finalResponse) + 
populateOpenAIUsageFromResponseJSON(finalResponse, usage) + if responseID == "" { + responseID = strings.TrimSpace(gjson.GetBytes(finalResponse, "id").String()) + } + + c.Data(http.StatusOK, "application/json", finalResponse) + } else { + flushStreamWriter(true) + } + + if responseID != "" && stateStore != nil { + ttl := s.openAIWSResponseStickyTTL() + logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl)) + stateStore.BindResponseConn(responseID, lease.ConnID(), ttl) + } + if stateStore != nil && storeDisabled && sessionHash != "" { + stateStore.BindSessionConn(groupID, sessionHash, lease.ConnID(), s.openAIWSSessionStickyTTL()) + } + firstTokenMsValue := -1 + if firstTokenMs != nil { + firstTokenMsValue = *firstTokenMs + } + logOpenAIWSModeDebug( + "completed account_id=%d conn_id=%s response_id=%s stream=%v duration_ms=%d events=%d token_events=%d terminal_events=%d buffered_events=%d buffered_flushed=%d first_event=%s last_event=%s first_token_ms=%d wrote_downstream=%v client_disconnected=%v", + account.ID, + connID, + truncateOpenAIWSLogValue(strings.TrimSpace(responseID), openAIWSIDValueMaxLen), + reqStream, + time.Since(startTime).Milliseconds(), + eventCount, + tokenEventCount, + terminalEventCount, + bufferedEventCount, + flushedBufferedEventCount, + truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen), + firstTokenMsValue, + wroteDownstream, + clientDisconnected, + ) + + return &OpenAIForwardResult{ + RequestID: responseID, + Usage: *usage, + Model: originalModel, + ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel), + Stream: reqStream, + OpenAIWSMode: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, nil +} + +// ProxyResponsesWebSocketFromClient 处理客户端入站 WebSocket(OpenAI Responses WS Mode)并转发到上游。 +// 当前实现按“单请求 -> 终止事件 -> 下一请求”的顺序代理,适配 Codex CLI 的 
turn 模式。 +func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( + ctx context.Context, + c *gin.Context, + clientConn *coderws.Conn, + account *Account, + token string, + firstClientMessage []byte, + hooks *OpenAIWSIngressHooks, +) error { + if s == nil { + return errors.New("service is nil") + } + if c == nil { + return errors.New("gin context is nil") + } + if clientConn == nil { + return errors.New("client websocket is nil") + } + if account == nil { + return errors.New("account is nil") + } + if strings.TrimSpace(token) == "" { + return errors.New("token is empty") + } + + wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) + modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled + ingressMode := OpenAIWSIngressModeShared + if modeRouterV2Enabled { + ingressMode = account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault) + if ingressMode == OpenAIWSIngressModeOff { + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "websocket mode is disabled for this account", + nil, + ) + } + } + if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport) + } + dedicatedMode := modeRouterV2Enabled && ingressMode == OpenAIWSIngressModeDedicated + + wsURL, err := s.buildOpenAIResponsesWSURL(account) + if err != nil { + return fmt.Errorf("build ws url: %w", err) + } + wsHost := "-" + wsPath := "-" + if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil { + wsHost = normalizeOpenAIWSLogValue(parsedURL.Host) + wsPath = normalizeOpenAIWSLogValue(parsedURL.Path) + } + debugEnabled := isOpenAIWSModeDebugEnabled() + + type openAIWSClientPayload struct { + payloadRaw []byte + rawForHash []byte + promptCacheKey string + previousResponseID string + originalModel string + payloadBytes int + } + + applyPayloadMutation := func(current []byte, 
path string, value any) ([]byte, error) { + next, err := sjson.SetBytes(current, path, value) + if err == nil { + return next, nil + } + + // 仅在确实需要修改 payload 且 sjson 失败时,退回 map 路径确保兼容性。 + payload := make(map[string]any) + if unmarshalErr := json.Unmarshal(current, &payload); unmarshalErr != nil { + return nil, err + } + switch path { + case "type", "model": + payload[path] = value + case "client_metadata." + openAIWSTurnMetadataHeader: + setOpenAIWSTurnMetadata(payload, fmt.Sprintf("%v", value)) + default: + return nil, err + } + rebuilt, marshalErr := json.Marshal(payload) + if marshalErr != nil { + return nil, marshalErr + } + return rebuilt, nil + } + + parseClientPayload := func(raw []byte) (openAIWSClientPayload, error) { + trimmed := bytes.TrimSpace(raw) + if len(trimmed) == 0 { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "empty websocket request payload", nil) + } + if !gjson.ValidBytes(trimmed) { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json")) + } + + values := gjson.GetManyBytes(trimmed, "type", "model", "prompt_cache_key", "previous_response_id") + eventType := strings.TrimSpace(values[0].String()) + normalized := trimmed + switch eventType { + case "": + eventType = "response.create" + next, setErr := applyPayloadMutation(normalized, "type", eventType) + if setErr != nil { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr) + } + normalized = next + case "response.create": + case "response.append": + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "response.append is not supported in ws v2; use response.create with previous_response_id", + nil, + ) + default: + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + 
fmt.Sprintf("unsupported websocket request type: %s", eventType), + nil, + ) + } + + originalModel := strings.TrimSpace(values[1].String()) + if originalModel == "" { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "model is required in response.create payload", + nil, + ) + } + promptCacheKey := strings.TrimSpace(values[2].String()) + previousResponseID := strings.TrimSpace(values[3].String()) + previousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(previousResponseID) + if previousResponseID != "" && previousResponseIDKind == OpenAIPreviousResponseIDKindMessageID { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "previous_response_id must be a response.id (resp_*), not a message id", + nil, + ) + } + if turnMetadata := strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)); turnMetadata != "" { + next, setErr := applyPayloadMutation(normalized, "client_metadata."+openAIWSTurnMetadataHeader, turnMetadata) + if setErr != nil { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr) + } + normalized = next + } + mappedModel := account.GetMappedModel(originalModel) + if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" { + mappedModel = normalizedModel + } + if mappedModel != originalModel { + next, setErr := applyPayloadMutation(normalized, "model", mappedModel) + if setErr != nil { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr) + } + normalized = next + } + + return openAIWSClientPayload{ + payloadRaw: normalized, + rawForHash: trimmed, + promptCacheKey: promptCacheKey, + previousResponseID: previousResponseID, + originalModel: originalModel, + payloadBytes: len(normalized), + }, nil + } + + firstPayload, err := parseClientPayload(firstClientMessage) + if err 
!= nil { + return err + } + + turnState := strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader)) + stateStore := s.getOpenAIWSStateStore() + groupID := getOpenAIGroupIDFromContext(c) + sessionHash := s.GenerateSessionHash(c, firstPayload.rawForHash) + if turnState == "" && stateStore != nil && sessionHash != "" { + if savedTurnState, ok := stateStore.GetSessionTurnState(groupID, sessionHash); ok { + turnState = savedTurnState + } + } + + preferredConnID := "" + if stateStore != nil && firstPayload.previousResponseID != "" { + if connID, ok := stateStore.GetResponseConn(firstPayload.previousResponseID); ok { + preferredConnID = connID + } + } + + storeDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(firstPayload.payloadRaw, account) + storeDisabledConnMode := s.openAIWSStoreDisabledConnMode() + if stateStore != nil && storeDisabled && firstPayload.previousResponseID == "" && sessionHash != "" { + if connID, ok := stateStore.GetSessionConn(groupID, sessionHash); ok { + preferredConnID = connID + } + } + + isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) + wsHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), firstPayload.promptCacheKey) + baseAcquireReq := openAIWSAcquireRequest{ + Account: account, + WSURL: wsURL, + Headers: wsHeaders, + ProxyURL: func() string { + if account.ProxyID != nil && account.Proxy != nil { + return account.Proxy.URL() + } + return "" + }(), + ForceNewConn: false, + } + pool := s.getOpenAIWSConnPool() + if pool == nil { + return errors.New("openai ws conn pool is nil") + } + + logOpenAIWSModeInfo( + "ingress_ws_protocol_confirm account_id=%d account_type=%s transport=%s ws_host=%s ws_path=%s ws_mode=%s store_disabled=%v has_session_hash=%v has_previous_response_id=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(wsDecision.Transport)), + wsHost, + wsPath, 
+ normalizeOpenAIWSLogValue(ingressMode), + storeDisabled, + sessionHash != "", + firstPayload.previousResponseID != "", + ) + + if debugEnabled { + logOpenAIWSModeDebug( + "ingress_ws_start account_id=%d account_type=%s transport=%s ws_host=%s preferred_conn_id=%s has_session_hash=%v has_previous_response_id=%v store_disabled=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(wsDecision.Transport)), + wsHost, + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + sessionHash != "", + firstPayload.previousResponseID != "", + storeDisabled, + ) + } + if firstPayload.previousResponseID != "" { + firstPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(firstPayload.previousResponseID) + logOpenAIWSModeInfo( + "ingress_ws_continuation_probe account_id=%d turn=%d previous_response_id=%s previous_response_id_kind=%s preferred_conn_id=%s session_hash=%s header_session_id=%s header_conversation_id=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v store_disabled=%v", + account.ID, + 1, + truncateOpenAIWSLogValue(firstPayload.previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(firstPreviousResponseIDKind), + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(sessionHash, 12), + openAIWSHeaderValueForLog(baseAcquireReq.Headers, "session_id"), + openAIWSHeaderValueForLog(baseAcquireReq.Headers, "conversation_id"), + turnState != "", + len(turnState), + firstPayload.promptCacheKey != "", + storeDisabled, + ) + } + + acquireTimeout := s.openAIWSAcquireTimeout() + if acquireTimeout <= 0 { + acquireTimeout = 30 * time.Second + } + + acquireTurnLease := func(turn int, preferred string, forcePreferredConn bool) (*openAIWSConnLease, error) { + req := cloneOpenAIWSAcquireRequest(baseAcquireReq) + req.PreferredConnID = strings.TrimSpace(preferred) + req.ForcePreferredConn = forcePreferredConn + // dedicated 模式下每次获取均新建连接,避免跨会话复用残留上下文。 + req.ForceNewConn = dedicatedMode 
+ acquireCtx, acquireCancel := context.WithTimeout(ctx, acquireTimeout) + lease, acquireErr := pool.Acquire(acquireCtx, req) + acquireCancel() + if acquireErr != nil { + dialStatus, dialClass, dialCloseStatus, dialCloseReason, dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID := summarizeOpenAIWSDialError(acquireErr) + logOpenAIWSModeInfo( + "ingress_ws_upstream_acquire_fail account_id=%d turn=%d reason=%s dial_status=%d dial_class=%s dial_close_status=%s dial_close_reason=%s dial_resp_server=%s dial_resp_via=%s dial_resp_cf_ray=%s dial_resp_x_request_id=%s cause=%s preferred_conn_id=%s force_preferred_conn=%v ws_host=%s ws_path=%s proxy_enabled=%v", + account.ID, + turn, + normalizeOpenAIWSLogValue(classifyOpenAIWSAcquireError(acquireErr)), + dialStatus, + dialClass, + dialCloseStatus, + truncateOpenAIWSLogValue(dialCloseReason, openAIWSHeaderValueMaxLen), + dialRespServer, + dialRespVia, + dialRespCFRay, + dialRespReqID, + truncateOpenAIWSLogValue(acquireErr.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(preferred, openAIWSIDValueMaxLen), + forcePreferredConn, + wsHost, + wsPath, + account.ProxyID != nil && account.Proxy != nil, + ) + if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) { + return nil, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "upstream continuation connection is unavailable; please restart the conversation", + acquireErr, + ) + } + if errors.Is(acquireErr, context.DeadlineExceeded) || errors.Is(acquireErr, errOpenAIWSConnQueueFull) { + return nil, NewOpenAIWSClientCloseError( + coderws.StatusTryAgainLater, + "upstream websocket is busy, please retry later", + acquireErr, + ) + } + return nil, acquireErr + } + connID := strings.TrimSpace(lease.ConnID()) + if handshakeTurnState := strings.TrimSpace(lease.HandshakeHeader(openAIWSTurnStateHeader)); handshakeTurnState != "" { + turnState = handshakeTurnState + if stateStore != nil && sessionHash != "" { + stateStore.BindSessionTurnState(groupID, 
sessionHash, handshakeTurnState, s.openAIWSSessionStickyTTL()) + } + updatedHeaders := cloneHeader(baseAcquireReq.Headers) + if updatedHeaders == nil { + updatedHeaders = make(http.Header) + } + updatedHeaders.Set(openAIWSTurnStateHeader, handshakeTurnState) + baseAcquireReq.Headers = updatedHeaders + } + logOpenAIWSModeInfo( + "ingress_ws_upstream_connected account_id=%d turn=%d conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d preferred_conn_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + lease.Reused(), + lease.ConnPickDuration().Milliseconds(), + lease.QueueWaitDuration().Milliseconds(), + truncateOpenAIWSLogValue(preferred, openAIWSIDValueMaxLen), + ) + return lease, nil + } + + writeClientMessage := func(message []byte) error { + writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout()) + defer cancel() + return clientConn.Write(writeCtx, coderws.MessageText, message) + } + + readClientMessage := func() ([]byte, error) { + msgType, payload, readErr := clientConn.Read(ctx) + if readErr != nil { + return nil, readErr + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + return nil, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + fmt.Sprintf("unsupported websocket client message type: %s", msgType.String()), + nil, + ) + } + return payload, nil + } + + sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string) (*OpenAIForwardResult, error) { + if lease == nil { + return nil, errors.New("upstream websocket lease is nil") + } + turnStart := time.Now() + wroteDownstream := false + if err := lease.WriteJSONWithContextTimeout(ctx, json.RawMessage(payload), s.openAIWSWriteTimeout()); err != nil { + return nil, wrapOpenAIWSIngressTurnError( + "write_upstream", + fmt.Errorf("write upstream websocket request: %w", err), + false, + ) + } + if debugEnabled { + logOpenAIWSModeDebug( + "ingress_ws_turn_request_sent 
account_id=%d turn=%d conn_id=%s payload_bytes=%d", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + payloadBytes, + ) + } + + responseID := "" + usage := OpenAIUsage{} + var firstTokenMs *int + reqStream := openAIWSPayloadBoolFromRaw(payload, "stream", true) + turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id") + turnPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(turnPreviousResponseID) + turnPromptCacheKey := openAIWSPayloadStringFromRaw(payload, "prompt_cache_key") + turnStoreDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(payload, account) + turnHasFunctionCallOutput := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() + eventCount := 0 + tokenEventCount := 0 + terminalEventCount := 0 + firstEventType := "" + lastEventType := "" + needModelReplace := false + clientDisconnected := false + mappedModel := "" + var mappedModelBytes []byte + if originalModel != "" { + mappedModel = account.GetMappedModel(originalModel) + if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" { + mappedModel = normalizedModel + } + needModelReplace = mappedModel != "" && mappedModel != originalModel + if needModelReplace { + mappedModelBytes = []byte(mappedModel) + } + } + for { + upstreamMessage, readErr := lease.ReadMessageWithContextTimeout(ctx, s.openAIWSReadTimeout()) + if readErr != nil { + lease.MarkBroken() + return nil, wrapOpenAIWSIngressTurnError( + "read_upstream", + fmt.Errorf("read upstream websocket event: %w", readErr), + wroteDownstream, + ) + } + + eventType, eventResponseID, _ := parseOpenAIWSEventEnvelope(upstreamMessage) + if responseID == "" && eventResponseID != "" { + responseID = eventResponseID + } + if eventType != "" { + eventCount++ + if firstEventType == "" { + firstEventType = eventType + } + lastEventType = eventType + } + if eventType == "error" { + errCodeRaw, errTypeRaw, errMsgRaw := 
parseOpenAIWSErrorEventFields(upstreamMessage) + fallbackReason, _ := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + recoverablePrevNotFound := fallbackReason == openAIWSIngressStagePreviousResponseNotFound && + turnPreviousResponseID != "" && + !turnHasFunctionCallOutput && + s.openAIWSIngressPreviousResponseRecoveryEnabled() && + !wroteDownstream + if recoverablePrevNotFound { + // 可恢复场景使用非 error 关键字日志,避免被 LegacyPrintf 误判为 ERROR 级别。 + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recoverable account_id=%d turn=%d conn_id=%s idx=%d reason=%s code=%s type=%s message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s store_disabled=%v has_prompt_cache_key=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + eventCount, + truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), + errCode, + errType, + errMessage, + truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(turnPreviousResponseIDKind), + truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + turnStoreDisabled, + turnPromptCacheKey != "", + ) + } else { + logOpenAIWSModeInfo( + "ingress_ws_error_event account_id=%d turn=%d conn_id=%s idx=%d fallback_reason=%s err_code=%s err_type=%s err_message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s store_disabled=%v has_prompt_cache_key=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + eventCount, + truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), + errCode, + errType, + errMessage, + truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(turnPreviousResponseIDKind), + truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + turnStoreDisabled, + turnPromptCacheKey != 
"", + ) + } + // previous_response_not_found 在 ingress 模式支持单次恢复重试: + // 不把该 error 直接下发客户端,而是由上层去掉 previous_response_id 后重放当前 turn。 + if recoverablePrevNotFound { + lease.MarkBroken() + errMsg := strings.TrimSpace(errMsgRaw) + if errMsg == "" { + errMsg = "previous response not found" + } + return nil, wrapOpenAIWSIngressTurnError( + openAIWSIngressStagePreviousResponseNotFound, + errors.New(errMsg), + false, + ) + } + } + isTokenEvent := isOpenAIWSTokenEvent(eventType) + if isTokenEvent { + tokenEventCount++ + } + isTerminalEvent := isOpenAIWSTerminalEvent(eventType) + if isTerminalEvent { + terminalEventCount++ + } + if firstTokenMs == nil && isTokenEvent { + ms := int(time.Since(turnStart).Milliseconds()) + firstTokenMs = &ms + } + if openAIWSEventShouldParseUsage(eventType) { + parseOpenAIWSResponseUsageFromCompletedEvent(upstreamMessage, &usage) + } + + if !clientDisconnected { + if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(upstreamMessage, mappedModelBytes) { + upstreamMessage = replaceOpenAIWSMessageModel(upstreamMessage, mappedModel, originalModel) + } + if openAIWSEventMayContainToolCalls(eventType) && openAIWSMessageLikelyContainsToolCalls(upstreamMessage) { + if corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(upstreamMessage); changed { + upstreamMessage = corrected + } + } + if err := writeClientMessage(upstreamMessage); err != nil { + if isOpenAIWSClientDisconnectError(err) { + clientDisconnected = true + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(err) + logOpenAIWSModeInfo( + "ingress_ws_client_disconnected_drain account_id=%d turn=%d conn_id=%s close_status=%s close_reason=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + closeStatus, + truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen), + ) + } else { + return nil, wrapOpenAIWSIngressTurnError( + "write_client", + fmt.Errorf("write client 
websocket event: %w", err), + wroteDownstream, + ) + } + } else { + wroteDownstream = true + } + } + if isTerminalEvent { + // 客户端已断连时,上游连接的 session 状态不可信,标记 broken 避免回池复用。 + if clientDisconnected { + lease.MarkBroken() + } + firstTokenMsValue := -1 + if firstTokenMs != nil { + firstTokenMsValue = *firstTokenMs + } + if debugEnabled { + logOpenAIWSModeDebug( + "ingress_ws_turn_completed account_id=%d turn=%d conn_id=%s response_id=%s duration_ms=%d events=%d token_events=%d terminal_events=%d first_event=%s last_event=%s first_token_ms=%d client_disconnected=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + time.Since(turnStart).Milliseconds(), + eventCount, + tokenEventCount, + terminalEventCount, + truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen), + firstTokenMsValue, + clientDisconnected, + ) + } + return &OpenAIForwardResult{ + RequestID: responseID, + Usage: usage, + Model: originalModel, + ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel), + Stream: reqStream, + OpenAIWSMode: true, + Duration: time.Since(turnStart), + FirstTokenMs: firstTokenMs, + }, nil + } + } + } + + currentPayload := firstPayload.payloadRaw + currentOriginalModel := firstPayload.originalModel + currentPayloadBytes := firstPayload.payloadBytes + isStrictAffinityTurn := func(payload []byte) bool { + if !storeDisabled { + return false + } + return strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id")) != "" + } + var sessionLease *openAIWSConnLease + sessionConnID := "" + pinnedSessionConnID := "" + unpinSessionConn := func(connID string) { + connID = strings.TrimSpace(connID) + if connID == "" || pinnedSessionConnID != connID { + return + } + pool.UnpinConn(account.ID, connID) + pinnedSessionConnID = "" + } + pinSessionConn := func(connID string) { 
+ if !storeDisabled { + return + } + connID = strings.TrimSpace(connID) + if connID == "" || pinnedSessionConnID == connID { + return + } + if pinnedSessionConnID != "" { + pool.UnpinConn(account.ID, pinnedSessionConnID) + pinnedSessionConnID = "" + } + if pool.PinConn(account.ID, connID) { + pinnedSessionConnID = connID + } + } + releaseSessionLease := func() { + if sessionLease == nil { + return + } + if dedicatedMode { + // dedicated 会话结束后主动标记损坏,确保连接不会跨会话复用。 + sessionLease.MarkBroken() + } + unpinSessionConn(sessionConnID) + sessionLease.Release() + if debugEnabled { + logOpenAIWSModeDebug( + "ingress_ws_upstream_released account_id=%d conn_id=%s", + account.ID, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + ) + } + } + defer releaseSessionLease() + + turn := 1 + turnRetry := 0 + turnPrevRecoveryTried := false + lastTurnFinishedAt := time.Time{} + lastTurnResponseID := "" + lastTurnPayload := []byte(nil) + var lastTurnStrictState *openAIWSIngressPreviousTurnStrictState + lastTurnReplayInput := []json.RawMessage(nil) + lastTurnReplayInputExists := false + currentTurnReplayInput := []json.RawMessage(nil) + currentTurnReplayInputExists := false + skipBeforeTurn := false + resetSessionLease := func(markBroken bool) { + if sessionLease == nil { + return + } + if markBroken { + sessionLease.MarkBroken() + } + releaseSessionLease() + sessionLease = nil + sessionConnID = "" + preferredConnID = "" + } + recoverIngressPrevResponseNotFound := func(relayErr error, turn int, connID string) bool { + if !isOpenAIWSIngressPreviousResponseNotFound(relayErr) { + return false + } + if turnPrevRecoveryTried || !s.openAIWSIngressPreviousResponseRecoveryEnabled() { + return false + } + if isStrictAffinityTurn(currentPayload) { + // Layer 2:严格亲和链路命中 previous_response_not_found 时,降级为“去掉 previous_response_id 后重放一次”。 + // 该错误说明续链锚点已失效,继续 strict fail-close 只会直接中断本轮请求。 + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery_layer2 account_id=%d turn=%d conn_id=%s 
store_disabled_conn_mode=%s action=drop_previous_response_id_retry", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(storeDisabledConnMode), + ) + } + turnPrevRecoveryTried = true + updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload) + if dropErr != nil || !removed { + reason := "not_removed" + if dropErr != nil { + reason = "drop_error" + } + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(reason), + ) + return false + } + updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( + updatedPayload, + currentTurnReplayInput, + currentTurnReplayInputExists, + ) + if setInputErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen), + ) + return false + } + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id retry=1", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + ) + currentPayload = updatedWithInput + currentPayloadBytes = len(updatedWithInput) + resetSessionLease(true) + skipBeforeTurn = true + return true + } + retryIngressTurn := func(relayErr error, turn int, connID string) bool { + if !isOpenAIWSIngressTurnRetryable(relayErr) || turnRetry >= 1 { + return false + } + if isStrictAffinityTurn(currentPayload) { + logOpenAIWSModeInfo( + "ingress_ws_turn_retry_skip account_id=%d turn=%d conn_id=%s reason=strict_affinity", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + ) + return false + } + turnRetry++ + 
logOpenAIWSModeInfo( + "ingress_ws_turn_retry account_id=%d turn=%d retry=%d reason=%s conn_id=%s", + account.ID, + turn, + turnRetry, + truncateOpenAIWSLogValue(openAIWSIngressTurnRetryReason(relayErr), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + ) + resetSessionLease(true) + skipBeforeTurn = true + return true + } + for { + if !skipBeforeTurn && hooks != nil && hooks.BeforeTurn != nil { + if err := hooks.BeforeTurn(turn); err != nil { + return err + } + } + skipBeforeTurn = false + currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id") + expectedPrev := strings.TrimSpace(lastTurnResponseID) + hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists() + // store=false + function_call_output 场景必须有续链锚点。 + // 若客户端未传 previous_response_id,优先回填上一轮响应 ID,避免上游报 call_id 无法关联。 + if shouldInferIngressFunctionCallOutputPreviousResponseID( + storeDisabled, + turn, + hasFunctionCallOutput, + currentPreviousResponseID, + expectedPrev, + ) { + updatedPayload, setPrevErr := setPreviousResponseIDToRawPayload(currentPayload, expectedPrev) + if setPrevErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_function_call_output_prev_infer_skip account_id=%d turn=%d conn_id=%s reason=set_previous_response_id_error cause=%s expected_previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setPrevErr.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + ) + } else { + currentPayload = updatedPayload + currentPayloadBytes = len(updatedPayload) + currentPreviousResponseID = expectedPrev + logOpenAIWSModeInfo( + "ingress_ws_function_call_output_prev_infer account_id=%d turn=%d conn_id=%s action=set_previous_response_id previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + 
truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + ) + } + } + nextReplayInput, nextReplayInputExists, replayInputErr := buildOpenAIWSReplayInputSequence( + lastTurnReplayInput, + lastTurnReplayInputExists, + currentPayload, + currentPreviousResponseID != "", + ) + if replayInputErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_replay_input_skip account_id=%d turn=%d conn_id=%s reason=build_error cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(replayInputErr.Error(), openAIWSLogValueMaxLen), + ) + currentTurnReplayInput = nil + currentTurnReplayInputExists = false + } else { + currentTurnReplayInput = nextReplayInput + currentTurnReplayInputExists = nextReplayInputExists + } + if storeDisabled && turn > 1 && currentPreviousResponseID != "" { + shouldKeepPreviousResponseID := false + strictReason := "" + var strictErr error + if lastTurnStrictState != nil { + shouldKeepPreviousResponseID, strictReason, strictErr = shouldKeepIngressPreviousResponseIDWithStrictState( + lastTurnStrictState, + currentPayload, + lastTurnResponseID, + hasFunctionCallOutput, + ) + } else { + shouldKeepPreviousResponseID, strictReason, strictErr = shouldKeepIngressPreviousResponseID( + lastTurnPayload, + currentPayload, + lastTurnResponseID, + hasFunctionCallOutput, + ) + } + if strictErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s cause=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(strictReason), + truncateOpenAIWSLogValue(strictErr.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + hasFunctionCallOutput, + ) + } 
else if !shouldKeepPreviousResponseID { + updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload) + if dropErr != nil || !removed { + dropReason := "not_removed" + if dropErr != nil { + dropReason = "drop_error" + } + logOpenAIWSModeInfo( + "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s drop_reason=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(strictReason), + normalizeOpenAIWSLogValue(dropReason), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + hasFunctionCallOutput, + ) + } else { + updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( + updatedPayload, + currentTurnReplayInput, + currentTurnReplayInputExists, + ) + if setInputErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s drop_reason=set_full_input_error previous_response_id=%s expected_previous_response_id=%s cause=%s has_function_call_output=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(strictReason), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen), + hasFunctionCallOutput, + ) + } else { + currentPayload = updatedWithInput + currentPayloadBytes = len(updatedWithInput) + logOpenAIWSModeInfo( + "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_full_create reason=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v", + 
account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(strictReason), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + hasFunctionCallOutput, + ) + currentPreviousResponseID = "" + } + } + } + } + forcePreferredConn := isStrictAffinityTurn(currentPayload) + if sessionLease == nil { + acquiredLease, acquireErr := acquireTurnLease(turn, preferredConnID, forcePreferredConn) + if acquireErr != nil { + return fmt.Errorf("acquire upstream websocket: %w", acquireErr) + } + sessionLease = acquiredLease + sessionConnID = strings.TrimSpace(sessionLease.ConnID()) + if storeDisabled { + pinSessionConn(sessionConnID) + } else { + unpinSessionConn(sessionConnID) + } + } + shouldPreflightPing := turn > 1 && sessionLease != nil && turnRetry == 0 + if shouldPreflightPing && openAIWSIngressPreflightPingIdle > 0 && !lastTurnFinishedAt.IsZero() { + if time.Since(lastTurnFinishedAt) < openAIWSIngressPreflightPingIdle { + shouldPreflightPing = false + } + } + if shouldPreflightPing { + if pingErr := sessionLease.PingWithTimeout(openAIWSConnHealthCheckTO); pingErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_upstream_preflight_ping_fail account_id=%d turn=%d conn_id=%s cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(pingErr.Error(), openAIWSLogValueMaxLen), + ) + if forcePreferredConn { + if !turnPrevRecoveryTried && currentPreviousResponseID != "" { + updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload) + if dropErr != nil || !removed { + reason := "not_removed" + if dropErr != nil { + reason = "drop_error" + } + logOpenAIWSModeInfo( + "ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=%s previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, 
openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(reason), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + ) + } else { + updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( + updatedPayload, + currentTurnReplayInput, + currentTurnReplayInputExists, + ) + if setInputErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error previous_response_id=%s cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen), + ) + } else { + logOpenAIWSModeInfo( + "ingress_ws_preflight_ping_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_retry previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + ) + turnPrevRecoveryTried = true + currentPayload = updatedWithInput + currentPayloadBytes = len(updatedWithInput) + resetSessionLease(true) + skipBeforeTurn = true + continue + } + } + } + resetSessionLease(true) + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "upstream continuation connection is unavailable; please restart the conversation", + pingErr, + ) + } + resetSessionLease(true) + + acquiredLease, acquireErr := acquireTurnLease(turn, preferredConnID, forcePreferredConn) + if acquireErr != nil { + return fmt.Errorf("acquire upstream websocket after preflight ping fail: %w", acquireErr) + } + sessionLease = acquiredLease + sessionConnID = strings.TrimSpace(sessionLease.ConnID()) + if storeDisabled { + pinSessionConn(sessionConnID) + } + } + } + connID := sessionConnID + if currentPreviousResponseID != "" { + chainedFromLast := expectedPrev != "" && currentPreviousResponseID == 
expectedPrev + currentPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(currentPreviousResponseID) + logOpenAIWSModeInfo( + "ingress_ws_turn_chain account_id=%d turn=%d conn_id=%s previous_response_id=%s previous_response_id_kind=%s last_turn_response_id=%s chained_from_last=%v preferred_conn_id=%s header_session_id=%s header_conversation_id=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v store_disabled=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(currentPreviousResponseIDKind), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + chainedFromLast, + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + openAIWSHeaderValueForLog(baseAcquireReq.Headers, "session_id"), + openAIWSHeaderValueForLog(baseAcquireReq.Headers, "conversation_id"), + turnState != "", + len(turnState), + openAIWSPayloadStringFromRaw(currentPayload, "prompt_cache_key") != "", + storeDisabled, + ) + } + + result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel) + if relayErr != nil { + if recoverIngressPrevResponseNotFound(relayErr, turn, connID) { + continue + } + if retryIngressTurn(relayErr, turn, connID) { + continue + } + finalErr := relayErr + if unwrapped := errors.Unwrap(relayErr); unwrapped != nil { + finalErr = unwrapped + } + if hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(turn, nil, finalErr) + } + sessionLease.MarkBroken() + return finalErr + } + turnRetry = 0 + turnPrevRecoveryTried = false + lastTurnFinishedAt = time.Now() + if hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(turn, result, nil) + } + if result == nil { + return errors.New("websocket turn result is nil") + } + responseID := strings.TrimSpace(result.RequestID) + lastTurnResponseID = responseID + lastTurnPayload = 
cloneOpenAIWSPayloadBytes(currentPayload) + lastTurnReplayInput = cloneOpenAIWSRawMessages(currentTurnReplayInput) + lastTurnReplayInputExists = currentTurnReplayInputExists + nextStrictState, strictStateErr := buildOpenAIWSIngressPreviousTurnStrictState(currentPayload) + if strictStateErr != nil { + lastTurnStrictState = nil + logOpenAIWSModeInfo( + "ingress_ws_prev_response_strict_state_skip account_id=%d turn=%d conn_id=%s reason=build_error cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(strictStateErr.Error(), openAIWSLogValueMaxLen), + ) + } else { + lastTurnStrictState = nextStrictState + } + + if responseID != "" && stateStore != nil { + ttl := s.openAIWSResponseStickyTTL() + logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl)) + stateStore.BindResponseConn(responseID, connID, ttl) + } + if stateStore != nil && storeDisabled && sessionHash != "" { + stateStore.BindSessionConn(groupID, sessionHash, connID, s.openAIWSSessionStickyTTL()) + } + if connID != "" { + preferredConnID = connID + } + + nextClientMessage, readErr := readClientMessage() + if readErr != nil { + if isOpenAIWSClientDisconnectError(readErr) { + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) + logOpenAIWSModeInfo( + "ingress_ws_client_closed account_id=%d conn_id=%s close_status=%s close_reason=%s", + account.ID, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + closeStatus, + truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen), + ) + return nil + } + return fmt.Errorf("read client websocket request: %w", readErr) + } + + nextPayload, parseErr := parseClientPayload(nextClientMessage) + if parseErr != nil { + return parseErr + } + if nextPayload.promptCacheKey != "" { + // ingress 会话在整个客户端 WS 生命周期内复用同一上游连接; + // prompt_cache_key 对握手头的更新仅在未来需要重新建连时生效。 + updatedHeaders, _ := 
s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), nextPayload.promptCacheKey) + baseAcquireReq.Headers = updatedHeaders + } + if nextPayload.previousResponseID != "" { + expectedPrev := strings.TrimSpace(lastTurnResponseID) + chainedFromLast := expectedPrev != "" && nextPayload.previousResponseID == expectedPrev + nextPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(nextPayload.previousResponseID) + logOpenAIWSModeInfo( + "ingress_ws_next_turn_chain account_id=%d turn=%d next_turn=%d conn_id=%s previous_response_id=%s previous_response_id_kind=%s last_turn_response_id=%s chained_from_last=%v has_prompt_cache_key=%v store_disabled=%v", + account.ID, + turn, + turn+1, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(nextPreviousResponseIDKind), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + chainedFromLast, + nextPayload.promptCacheKey != "", + storeDisabled, + ) + } + if stateStore != nil && nextPayload.previousResponseID != "" { + if stickyConnID, ok := stateStore.GetResponseConn(nextPayload.previousResponseID); ok { + if sessionConnID != "" && stickyConnID != "" && stickyConnID != sessionConnID { + logOpenAIWSModeInfo( + "ingress_ws_keep_session_conn account_id=%d turn=%d conn_id=%s sticky_conn_id=%s previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(stickyConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen), + ) + } else { + preferredConnID = stickyConnID + } + } + } + currentPayload = nextPayload.payloadRaw + currentOriginalModel = nextPayload.originalModel + currentPayloadBytes = nextPayload.payloadBytes + storeDisabled = s.isOpenAIWSStoreDisabledInRequestRaw(currentPayload, 
account) + if !storeDisabled { + unpinSessionConn(sessionConnID) + } + turn++ + } +} + +func (s *OpenAIGatewayService) isOpenAIWSGeneratePrewarmEnabled() bool { + return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled +} + +// performOpenAIWSGeneratePrewarm 在 WSv2 下执行可选的 generate=false 预热。 +// 预热默认关闭,仅在配置开启后生效;失败时按可恢复错误回退到 HTTP。 +func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm( + ctx context.Context, + lease *openAIWSConnLease, + decision OpenAIWSProtocolDecision, + payload map[string]any, + previousResponseID string, + reqBody map[string]any, + account *Account, + stateStore OpenAIWSStateStore, + groupID int64, +) error { + if s == nil { + return nil + } + if lease == nil || account == nil { + logOpenAIWSModeInfo("prewarm_skip reason=invalid_state has_lease=%v has_account=%v", lease != nil, account != nil) + return nil + } + connID := strings.TrimSpace(lease.ConnID()) + if !s.isOpenAIWSGeneratePrewarmEnabled() { + return nil + } + if decision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + logOpenAIWSModeInfo( + "prewarm_skip account_id=%d conn_id=%s reason=transport_not_v2 transport=%s", + account.ID, + connID, + normalizeOpenAIWSLogValue(string(decision.Transport)), + ) + return nil + } + if strings.TrimSpace(previousResponseID) != "" { + logOpenAIWSModeInfo( + "prewarm_skip account_id=%d conn_id=%s reason=has_previous_response_id previous_response_id=%s", + account.ID, + connID, + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + ) + return nil + } + if lease.IsPrewarmed() { + logOpenAIWSModeInfo("prewarm_skip account_id=%d conn_id=%s reason=already_prewarmed", account.ID, connID) + return nil + } + if NeedsToolContinuation(reqBody) { + logOpenAIWSModeInfo("prewarm_skip account_id=%d conn_id=%s reason=tool_continuation", account.ID, connID) + return nil + } + prewarmStart := time.Now() + logOpenAIWSModeInfo("prewarm_start account_id=%d conn_id=%s", account.ID, connID) + + prewarmPayload 
:= make(map[string]any, len(payload)+1) + for k, v := range payload { + prewarmPayload[k] = v + } + prewarmPayload["generate"] = false + prewarmPayloadJSON := payloadAsJSONBytes(prewarmPayload) + + if err := lease.WriteJSONWithContextTimeout(ctx, prewarmPayload, s.openAIWSWriteTimeout()); err != nil { + lease.MarkBroken() + logOpenAIWSModeInfo( + "prewarm_write_fail account_id=%d conn_id=%s cause=%s", + account.ID, + connID, + truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + ) + return wrapOpenAIWSFallback("prewarm_write", err) + } + logOpenAIWSModeInfo("prewarm_write_sent account_id=%d conn_id=%s payload_bytes=%d", account.ID, connID, len(prewarmPayloadJSON)) + + prewarmResponseID := "" + prewarmEventCount := 0 + prewarmTerminalCount := 0 + for { + message, readErr := lease.ReadMessageWithContextTimeout(ctx, s.openAIWSReadTimeout()) + if readErr != nil { + lease.MarkBroken() + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) + logOpenAIWSModeInfo( + "prewarm_read_fail account_id=%d conn_id=%s close_status=%s close_reason=%s cause=%s events=%d", + account.ID, + connID, + closeStatus, + closeReason, + truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen), + prewarmEventCount, + ) + return wrapOpenAIWSFallback("prewarm_"+classifyOpenAIWSReadFallbackReason(readErr), readErr) + } + + eventType, eventResponseID, _ := parseOpenAIWSEventEnvelope(message) + if eventType == "" { + continue + } + prewarmEventCount++ + if prewarmResponseID == "" && eventResponseID != "" { + prewarmResponseID = eventResponseID + } + if prewarmEventCount <= openAIWSPrewarmEventLogHead || eventType == "error" || isOpenAIWSTerminalEvent(eventType) { + logOpenAIWSModeInfo( + "prewarm_event account_id=%d conn_id=%s idx=%d type=%s bytes=%d", + account.ID, + connID, + prewarmEventCount, + truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen), + len(message), + ) + } + + if eventType == "error" { + errCodeRaw, errTypeRaw, errMsgRaw := 
parseOpenAIWSErrorEventFields(message) + errMsg := strings.TrimSpace(errMsgRaw) + if errMsg == "" { + errMsg = "OpenAI websocket prewarm error" + } + fallbackReason, canFallback := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + logOpenAIWSModeInfo( + "prewarm_error_event account_id=%d conn_id=%s idx=%d fallback_reason=%s can_fallback=%v err_code=%s err_type=%s err_message=%s", + account.ID, + connID, + prewarmEventCount, + truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), + canFallback, + errCode, + errType, + errMessage, + ) + lease.MarkBroken() + if canFallback { + return wrapOpenAIWSFallback("prewarm_"+fallbackReason, errors.New(errMsg)) + } + return wrapOpenAIWSFallback("prewarm_error_event", errors.New(errMsg)) + } + + if isOpenAIWSTerminalEvent(eventType) { + prewarmTerminalCount++ + break + } + } + + lease.MarkPrewarmed() + if prewarmResponseID != "" && stateStore != nil { + ttl := s.openAIWSResponseStickyTTL() + logOpenAIWSBindResponseAccountWarn(groupID, account.ID, prewarmResponseID, stateStore.BindResponseAccount(ctx, groupID, prewarmResponseID, account.ID, ttl)) + stateStore.BindResponseConn(prewarmResponseID, lease.ConnID(), ttl) + } + logOpenAIWSModeInfo( + "prewarm_done account_id=%d conn_id=%s response_id=%s events=%d terminal_events=%d duration_ms=%d", + account.ID, + connID, + truncateOpenAIWSLogValue(prewarmResponseID, openAIWSIDValueMaxLen), + prewarmEventCount, + prewarmTerminalCount, + time.Since(prewarmStart).Milliseconds(), + ) + return nil +} + +func payloadAsJSON(payload map[string]any) string { + return string(payloadAsJSONBytes(payload)) +} + +func payloadAsJSONBytes(payload map[string]any) []byte { + if len(payload) == 0 { + return []byte("{}") + } + body, err := json.Marshal(payload) + if err != nil { + return []byte("{}") + } + return body +} + +func isOpenAIWSTerminalEvent(eventType 
string) bool { + switch strings.TrimSpace(eventType) { + case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled": + return true + default: + return false + } +} + +func isOpenAIWSTokenEvent(eventType string) bool { + eventType = strings.TrimSpace(eventType) + if eventType == "" { + return false + } + switch eventType { + case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done": + return false + } + if strings.Contains(eventType, ".delta") { + return true + } + if strings.HasPrefix(eventType, "response.output_text") { + return true + } + if strings.HasPrefix(eventType, "response.output") { + return true + } + return eventType == "response.completed" || eventType == "response.done" +} + +func replaceOpenAIWSMessageModel(message []byte, fromModel, toModel string) []byte { + if len(message) == 0 { + return message + } + if strings.TrimSpace(fromModel) == "" || strings.TrimSpace(toModel) == "" || fromModel == toModel { + return message + } + if !bytes.Contains(message, []byte(`"model"`)) || !bytes.Contains(message, []byte(fromModel)) { + return message + } + modelValues := gjson.GetManyBytes(message, "model", "response.model") + replaceModel := modelValues[0].Exists() && modelValues[0].Str == fromModel + replaceResponseModel := modelValues[1].Exists() && modelValues[1].Str == fromModel + if !replaceModel && !replaceResponseModel { + return message + } + updated := message + if replaceModel { + if next, err := sjson.SetBytes(updated, "model", toModel); err == nil { + updated = next + } + } + if replaceResponseModel { + if next, err := sjson.SetBytes(updated, "response.model", toModel); err == nil { + updated = next + } + } + return updated +} + +func populateOpenAIUsageFromResponseJSON(body []byte, usage *OpenAIUsage) { + if usage == nil || len(body) == 0 { + return + } + values := gjson.GetManyBytes( + body, + "usage.input_tokens", + 
"usage.output_tokens", + "usage.input_tokens_details.cached_tokens", + ) + usage.InputTokens = int(values[0].Int()) + usage.OutputTokens = int(values[1].Int()) + usage.CacheReadInputTokens = int(values[2].Int()) +} + +func getOpenAIGroupIDFromContext(c *gin.Context) int64 { + if c == nil { + return 0 + } + value, exists := c.Get("api_key") + if !exists { + return 0 + } + apiKey, ok := value.(*APIKey) + if !ok || apiKey == nil || apiKey.GroupID == nil { + return 0 + } + return *apiKey.GroupID +} + +// SelectAccountByPreviousResponseID 按 previous_response_id 命中账号粘连。 +// 未命中或账号不可用时返回 (nil, nil),由调用方继续走常规调度。 +func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID( + ctx context.Context, + groupID *int64, + previousResponseID string, + requestedModel string, + excludedIDs map[int64]struct{}, +) (*AccountSelectionResult, error) { + if s == nil { + return nil, nil + } + responseID := strings.TrimSpace(previousResponseID) + if responseID == "" { + return nil, nil + } + store := s.getOpenAIWSStateStore() + if store == nil { + return nil, nil + } + + accountID, err := store.GetResponseAccount(ctx, derefGroupID(groupID), responseID) + if err != nil || accountID <= 0 { + return nil, nil + } + if excludedIDs != nil { + if _, excluded := excludedIDs[accountID]; excluded { + return nil, nil + } + } + + account, err := s.getSchedulableAccount(ctx, accountID) + if err != nil || account == nil { + _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) + return nil, nil + } + // 非 WSv2 场景(如 force_http/全局关闭)不应使用 previous_response_id 粘连, + // 以保持“回滚到 HTTP”后的历史行为一致性。 + if s.getOpenAIWSProtocolResolver().Resolve(account).Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + return nil, nil + } + if shouldClearStickySession(account, requestedModel) || !account.IsOpenAI() { + _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) + return nil, nil + } + if requestedModel != "" && !account.IsModelSupported(requestedModel) { + return nil, 
nil + } + + result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) + if acquireErr == nil && result.Acquired { + logOpenAIWSBindResponseAccountWarn( + derefGroupID(groupID), + accountID, + responseID, + store.BindResponseAccount(ctx, derefGroupID(groupID), responseID, accountID, s.openAIWSResponseStickyTTL()), + ) + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + + cfg := s.schedulingConfig() + if s.concurrencyService != nil { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + return nil, nil +} + +func classifyOpenAIWSAcquireError(err error) string { + if err == nil { + return "acquire_conn" + } + var dialErr *openAIWSDialError + if errors.As(err, &dialErr) { + switch dialErr.StatusCode { + case 426: + return "upgrade_required" + case 401, 403: + return "auth_failed" + case 429: + return "upstream_rate_limited" + } + if dialErr.StatusCode >= 500 { + return "upstream_5xx" + } + return "dial_failed" + } + if errors.Is(err, errOpenAIWSConnQueueFull) { + return "conn_queue_full" + } + if errors.Is(err, errOpenAIWSPreferredConnUnavailable) { + return "preferred_conn_unavailable" + } + if errors.Is(err, context.DeadlineExceeded) { + return "acquire_timeout" + } + return "acquire_conn" +} + +func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) { + code := strings.ToLower(strings.TrimSpace(codeRaw)) + errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) + msg := strings.ToLower(strings.TrimSpace(msgRaw)) + + switch code { + case "upgrade_required": + return "upgrade_required", true + case "websocket_not_supported", "websocket_unsupported": + return "ws_unsupported", true + case "websocket_connection_limit_reached": + return 
"ws_connection_limit_reached", true + case "previous_response_not_found": + return "previous_response_not_found", true + } + if strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") { + return "upgrade_required", true + } + if strings.Contains(errType, "upgrade") { + return "upgrade_required", true + } + if strings.Contains(msg, "websocket") && strings.Contains(msg, "unsupported") { + return "ws_unsupported", true + } + if strings.Contains(msg, "connection limit") && strings.Contains(msg, "websocket") { + return "ws_connection_limit_reached", true + } + if strings.Contains(msg, "previous_response_not_found") || + (strings.Contains(msg, "previous response") && strings.Contains(msg, "not found")) { + return "previous_response_not_found", true + } + if strings.Contains(errType, "server_error") || strings.Contains(code, "server_error") { + return "upstream_error_event", true + } + return "event_error", false +} + +func classifyOpenAIWSErrorEvent(message []byte) (string, bool) { + if len(message) == 0 { + return "event_error", false + } + return classifyOpenAIWSErrorEventFromRaw(parseOpenAIWSErrorEventFields(message)) +} + +func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int { + code := strings.ToLower(strings.TrimSpace(codeRaw)) + errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) + switch { + case strings.Contains(errType, "invalid_request"), + strings.Contains(code, "invalid_request"), + strings.Contains(code, "bad_request"), + code == "previous_response_not_found": + return http.StatusBadRequest + case strings.Contains(errType, "authentication"), + strings.Contains(code, "invalid_api_key"), + strings.Contains(code, "unauthorized"): + return http.StatusUnauthorized + case strings.Contains(errType, "permission"), + strings.Contains(code, "forbidden"): + return http.StatusForbidden + case strings.Contains(errType, "rate_limit"), + strings.Contains(code, "rate_limit"), + strings.Contains(code, "insufficient_quota"): + 
return http.StatusTooManyRequests + default: + return http.StatusBadGateway + } +} + +func openAIWSErrorHTTPStatus(message []byte) int { + if len(message) == 0 { + return http.StatusBadGateway + } + codeRaw, errTypeRaw, _ := parseOpenAIWSErrorEventFields(message) + return openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw) +} + +func (s *OpenAIGatewayService) openAIWSFallbackCooldown() time.Duration { + if s == nil || s.cfg == nil { + return 30 * time.Second + } + seconds := s.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds + if seconds <= 0 { + return 0 + } + return time.Duration(seconds) * time.Second +} + +func (s *OpenAIGatewayService) isOpenAIWSFallbackCooling(accountID int64) bool { + if s == nil || accountID <= 0 { + return false + } + cooldown := s.openAIWSFallbackCooldown() + if cooldown <= 0 { + return false + } + rawUntil, ok := s.openaiWSFallbackUntil.Load(accountID) + if !ok || rawUntil == nil { + return false + } + until, ok := rawUntil.(time.Time) + if !ok || until.IsZero() { + s.openaiWSFallbackUntil.Delete(accountID) + return false + } + if time.Now().Before(until) { + return true + } + s.openaiWSFallbackUntil.Delete(accountID) + return false +} + +func (s *OpenAIGatewayService) markOpenAIWSFallbackCooling(accountID int64, _ string) { + if s == nil || accountID <= 0 { + return + } + cooldown := s.openAIWSFallbackCooldown() + if cooldown <= 0 { + return + } + s.openaiWSFallbackUntil.Store(accountID, time.Now().Add(cooldown)) +} + +func (s *OpenAIGatewayService) clearOpenAIWSFallbackCooling(accountID int64) { + if s == nil || accountID <= 0 { + return + } + s.openaiWSFallbackUntil.Delete(accountID) +} diff --git a/backend/internal/service/openai_ws_forwarder_benchmark_test.go b/backend/internal/service/openai_ws_forwarder_benchmark_test.go new file mode 100644 index 00000000..bd03ab5a --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_benchmark_test.go @@ -0,0 +1,127 @@ +package service + +import ( + "fmt" + "testing" + + 
"github.com/Wei-Shaw/sub2api/internal/config" +) + +var ( + benchmarkOpenAIWSPayloadJSONSink string + benchmarkOpenAIWSStringSink string + benchmarkOpenAIWSBoolSink bool + benchmarkOpenAIWSBytesSink []byte +) + +func BenchmarkOpenAIWSForwarderHotPath(b *testing.B) { + cfg := &config.Config{} + svc := &OpenAIGatewayService{cfg: cfg} + account := &Account{ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth} + reqBody := benchmarkOpenAIWSHotPathRequest() + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + payload := svc.buildOpenAIWSCreatePayload(reqBody, account) + _, _ = applyOpenAIWSRetryPayloadStrategy(payload, 2) + setOpenAIWSTurnMetadata(payload, `{"trace":"bench","turn":"1"}`) + + benchmarkOpenAIWSStringSink = openAIWSPayloadString(payload, "previous_response_id") + benchmarkOpenAIWSBoolSink = payload["tools"] != nil + benchmarkOpenAIWSStringSink = summarizeOpenAIWSPayloadKeySizes(payload, openAIWSPayloadKeySizeTopN) + benchmarkOpenAIWSStringSink = summarizeOpenAIWSInput(payload["input"]) + benchmarkOpenAIWSPayloadJSONSink = payloadAsJSON(payload) + } +} + +func benchmarkOpenAIWSHotPathRequest() map[string]any { + tools := make([]map[string]any, 0, 24) + for i := 0; i < 24; i++ { + tools = append(tools, map[string]any{ + "type": "function", + "name": fmt.Sprintf("tool_%02d", i), + "description": "benchmark tool schema", + "parameters": map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{"type": "string"}, + "limit": map[string]any{"type": "number"}, + }, + "required": []string{"query"}, + }, + }) + } + + input := make([]map[string]any, 0, 16) + for i := 0; i < 16; i++ { + input = append(input, map[string]any{ + "role": "user", + "type": "message", + "content": fmt.Sprintf("benchmark message %d", i), + }) + } + + return map[string]any{ + "type": "response.create", + "model": "gpt-5.3-codex", + "input": input, + "tools": tools, + "parallel_tool_calls": true, + "previous_response_id": 
"resp_benchmark_prev", + "prompt_cache_key": "bench-cache-key", + "reasoning": map[string]any{"effort": "medium"}, + "instructions": "benchmark instructions", + "store": false, + } +} + +func BenchmarkOpenAIWSEventEnvelopeParse(b *testing.B) { + event := []byte(`{"type":"response.completed","response":{"id":"resp_bench_1","model":"gpt-5.1","usage":{"input_tokens":12,"output_tokens":8}}}`) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + eventType, responseID, response := parseOpenAIWSEventEnvelope(event) + benchmarkOpenAIWSStringSink = eventType + benchmarkOpenAIWSStringSink = responseID + benchmarkOpenAIWSBoolSink = response.Exists() + } +} + +func BenchmarkOpenAIWSErrorEventFieldReuse(b *testing.B) { + event := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(event) + benchmarkOpenAIWSStringSink, benchmarkOpenAIWSBoolSink = classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, errMsgRaw) + code, errType, errMsg := summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMsgRaw) + benchmarkOpenAIWSStringSink = code + benchmarkOpenAIWSStringSink = errType + benchmarkOpenAIWSStringSink = errMsg + benchmarkOpenAIWSBoolSink = openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw) > 0 + } +} + +func BenchmarkReplaceOpenAIWSMessageModel_NoMatchFastPath(b *testing.B) { + event := []byte(`{"type":"response.output_text.delta","delta":"hello world"}`) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchmarkOpenAIWSBytesSink = replaceOpenAIWSMessageModel(event, "gpt-5.1", "custom-model") + } +} + +func BenchmarkReplaceOpenAIWSMessageModel_DualReplace(b *testing.B) { + event := []byte(`{"type":"response.completed","model":"gpt-5.1","response":{"id":"resp_1","model":"gpt-5.1"}}`) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; 
i < b.N; i++ { + benchmarkOpenAIWSBytesSink = replaceOpenAIWSMessageModel(event, "gpt-5.1", "custom-model") + } +} diff --git a/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go b/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go new file mode 100644 index 00000000..76167603 --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go @@ -0,0 +1,73 @@ +package service + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseOpenAIWSEventEnvelope(t *testing.T) { + eventType, responseID, response := parseOpenAIWSEventEnvelope([]byte(`{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.1"}}`)) + require.Equal(t, "response.completed", eventType) + require.Equal(t, "resp_1", responseID) + require.True(t, response.Exists()) + require.Equal(t, `{"id":"resp_1","model":"gpt-5.1"}`, response.Raw) + + eventType, responseID, response = parseOpenAIWSEventEnvelope([]byte(`{"type":"response.delta","id":"evt_1"}`)) + require.Equal(t, "response.delta", eventType) + require.Equal(t, "evt_1", responseID) + require.False(t, response.Exists()) +} + +func TestParseOpenAIWSResponseUsageFromCompletedEvent(t *testing.T) { + usage := &OpenAIUsage{} + parseOpenAIWSResponseUsageFromCompletedEvent( + []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":11,"output_tokens":7,"input_tokens_details":{"cached_tokens":3}}}}`), + usage, + ) + require.Equal(t, 11, usage.InputTokens) + require.Equal(t, 7, usage.OutputTokens) + require.Equal(t, 3, usage.CacheReadInputTokens) +} + +func TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper(t *testing.T) { + message := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`) + codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) + + wrappedReason, wrappedRecoverable := classifyOpenAIWSErrorEvent(message) + rawReason, 
rawRecoverable := classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, errMsgRaw) + require.Equal(t, wrappedReason, rawReason) + require.Equal(t, wrappedRecoverable, rawRecoverable) + + wrappedStatus := openAIWSErrorHTTPStatus(message) + rawStatus := openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw) + require.Equal(t, wrappedStatus, rawStatus) + require.Equal(t, http.StatusBadRequest, rawStatus) + + wrappedCode, wrappedType, wrappedMsg := summarizeOpenAIWSErrorEventFields(message) + rawCode, rawType, rawMsg := summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMsgRaw) + require.Equal(t, wrappedCode, rawCode) + require.Equal(t, wrappedType, rawType) + require.Equal(t, wrappedMsg, rawMsg) +} + +func TestOpenAIWSMessageLikelyContainsToolCalls(t *testing.T) { + require.False(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_text.delta","delta":"hello"}`))) + require.True(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_item.added","item":{"tool_calls":[{"id":"tc1"}]}}`))) + require.True(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_item.added","item":{"type":"function_call"}}`))) +} + +func TestReplaceOpenAIWSMessageModel_OptimizedStillCorrect(t *testing.T) { + noModel := []byte(`{"type":"response.output_text.delta","delta":"hello"}`) + require.Equal(t, string(noModel), string(replaceOpenAIWSMessageModel(noModel, "gpt-5.1", "custom-model"))) + + rootOnly := []byte(`{"type":"response.created","model":"gpt-5.1"}`) + require.Equal(t, `{"type":"response.created","model":"custom-model"}`, string(replaceOpenAIWSMessageModel(rootOnly, "gpt-5.1", "custom-model"))) + + responseOnly := []byte(`{"type":"response.completed","response":{"model":"gpt-5.1"}}`) + require.Equal(t, `{"type":"response.completed","response":{"model":"custom-model"}}`, string(replaceOpenAIWSMessageModel(responseOnly, "gpt-5.1", "custom-model"))) + + both := 
[]byte(`{"model":"gpt-5.1","response":{"model":"gpt-5.1"}}`) + require.Equal(t, `{"model":"custom-model","response":{"model":"custom-model"}}`, string(replaceOpenAIWSMessageModel(both, "gpt-5.1", "custom-model"))) +} diff --git a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go new file mode 100644 index 00000000..5a3c12c3 --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go @@ -0,0 +1,2483 @@ +package service + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossTurns(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_ingress_turn_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_ingress_turn_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := 
&openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 114, + Name: "openai-ingress-session-lease", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + turnWSModeCh := make(chan bool, 2) + hooks := &OpenAIWSIngressHooks{ + AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) { + if turnErr == nil && result != nil { + turnWSModeCh <- result.OpenAIWSMode + } + }, + } + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks) + })) + defer 
wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`) + firstTurnEvent := readMessage() + require.Equal(t, "response.completed", gjson.GetBytes(firstTurnEvent, "type").String()) + require.Equal(t, "resp_ingress_turn_1", gjson.GetBytes(firstTurnEvent, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_ingress_turn_1"}`) + secondTurnEvent := readMessage() + require.Equal(t, "response.completed", gjson.GetBytes(secondTurnEvent, "type").String()) + require.Equal(t, "resp_ingress_turn_2", gjson.GetBytes(secondTurnEvent, "response.id").String()) + require.True(t, <-turnWSModeCh, "首轮 turn 应标记为 WS 模式") + require.True(t, <-turnWSModeCh, "第二轮 turn 应标记为 WS 模式") + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + metrics := svc.SnapshotOpenAIWSPoolMetrics() + require.Equal(t, int64(1), metrics.AcquireTotal, "同一 ingress 会话多 turn 应只获取一次上游 lease") + require.Equal(t, 1, captureDialer.DialCount(), "同一 
ingress 会话应保持同一上游连接") + require.Len(t, captureConn.writes, 2, "应向同一上游连接发送两轮 response.create") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoesNotReuseConnAcrossSessions(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + upstreamConn1 := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_dedicated_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + upstreamConn2 := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_dedicated_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{upstreamConn1, upstreamConn2}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(dialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 441, + Name: "openai-ingress-dedicated", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: 
true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, + }, + } + + serverErrCh := make(chan error, 2) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + runSingleTurnSession := func(expectedResponseID string) { + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`)) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + msgType, event, readErr := 
clientConn.Read(readCtx) + cancelRead() + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + require.Equal(t, expectedResponseID, gjson.GetBytes(event, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + } + + runSingleTurnSession("resp_dedicated_1") + runSingleTurnSession("resp_dedicated_2") + + require.Equal(t, 2, dialer.DialCount(), "dedicated 模式下跨客户端会话不应复用上游连接") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: newOpenAIWSConnPool(cfg), + } + + account := &Account{ + ID: 442, + Name: "openai-ingress-off", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModeOff, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + 
CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`)) + cancelWrite() + require.NoError(t, err) + + select { + case serverErr := <-serverErrCh: + var closeErr *OpenAIWSClientCloseError + require.ErrorAs(t, serverErr, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode()) + require.Equal(t, "websocket mode is disabled for this account", closeErr.Reason()) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPrevResponseStrictDropToFullCreate(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + 
cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_preflight_rewrite_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_preflight_rewrite_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 140, + Name: "openai-ingress-prev-preflight-rewrite", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer 
func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + firstTurn := readMessage() + require.Equal(t, "resp_preflight_rewrite_1", gjson.GetBytes(firstTurn, "response.id").String()) + + 
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_stale_external","input":[{"type":"input_text","text":"world"}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_preflight_rewrite_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 1, captureDialer.DialCount(), "严格增量不成立时应在同一连接内降级为 full create") + require.Len(t, captureConn.writes, 2) + secondWrite := requestToJSONString(captureConn.writes[1]) + require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "严格增量不成立时应移除 previous_response_id,改为 full create") + require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array()), "严格降级为 full create 时应重放完整 input 上下文") + require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String()) + require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String()) +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPrevResponseStrictDropBeforePreflightPingFailReconnects(t *testing.T) { + gin.SetMode(gin.TestMode) + prevPreflightPingIdle := openAIWSIngressPreflightPingIdle + openAIWSIngressPreflightPingIdle = 0 + defer func() { + openAIWSIngressPreflightPingIdle = prevPreflightPingIdle + }() + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + 
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + firstConn := &openAIWSPreflightFailConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_drop_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + secondConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_drop_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{firstConn, secondConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(dialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 142, + Name: "openai-ingress-prev-strict-drop-before-ping", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr 
:= conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + firstTurn := readMessage() + require.Equal(t, "resp_turn_ping_drop_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_stale_external","input":[{"type":"input_text","text":"world"}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_turn_ping_drop_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 
严格降级后预检换连超时") + } + + require.Equal(t, 2, dialer.DialCount(), "严格降级为 full create 后,预检 ping 失败应允许换连") + require.Equal(t, 1, firstConn.WriteCount(), "首连接在预检失败后不应继续发送第二轮") + require.GreaterOrEqual(t, firstConn.PingCount(), 1, "第二轮前应执行 preflight ping") + secondConn.mu.Lock() + secondWrites := append([]map[string]any(nil), secondConn.writes...) + secondConn.mu.Unlock() + require.Len(t, secondWrites, 1) + secondWrite := requestToJSONString(secondWrites[0]) + require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "严格降级后重试应移除 previous_response_id") + require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array())) + require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String()) + require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String()) +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreEnabledSkipsStrictPrevResponseEval(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_store_enabled_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_store_enabled_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := 
newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 143, + Name: "openai-ingress-store-enabled-skip-strict", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = 
clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":true}`) + firstTurn := readMessage() + require.Equal(t, "resp_store_enabled_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":true,"previous_response_id":"resp_stale_external"}`) + secondTurn := readMessage() + require.Equal(t, "resp_store_enabled_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 store=true 场景 websocket 结束超时") + } + + require.Equal(t, 1, captureDialer.DialCount()) + require.Len(t, captureConn.writes, 2) + require.Equal(t, "resp_stale_external", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "store=true 场景不应触发 store-disabled strict 规则") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPrevResponsePreflightSkipForFunctionCallOutput(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 
= true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_preflight_skip_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_preflight_skip_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 141, + Name: "openai-ingress-prev-preflight-skip-fco", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, 
cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false}`) + firstTurn := readMessage() + require.Equal(t, "resp_preflight_skip_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_stale_external","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_preflight_skip_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) 
+ case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 1, captureDialer.DialCount()) + require.Len(t, captureConn.writes, 2) + require.Equal(t, "resp_stale_external", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "function_call_output 场景不应预改写 previous_response_id") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputAutoAttachPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 143, + Name: "openai-ingress-fco-auto-prev", + Platform: PlatformOpenAI, + Type: 
AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + 
require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + firstTurn := readMessage() + require.Equal(t, "resp_auto_prev_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call_output","call_id":"call_auto_1","output":"ok"}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_auto_prev_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 1, captureDialer.DialCount()) + require.Len(t, captureConn.writes, 2) + require.Equal(t, "resp_auto_prev_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "function_call_output 缺失 previous_response_id 时应回填上一轮响应 ID") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenLastResponseIDMissing(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := 
&openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_skip_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 144, + Name: "openai-ingress-fco-auto-prev-skip", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- 
svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + firstTurn := readMessage() + require.Equal(t, "response.completed", gjson.GetBytes(firstTurn, "type").String()) + require.Empty(t, gjson.GetBytes(firstTurn, "response.id").String(), "首轮响应不返回 response.id,模拟无法推导续链锚点") + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call_output","call_id":"call_auto_skip_1","output":"ok"}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_auto_prev_skip_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 1, captureDialer.DialCount()) + require.Len(t, captureConn.writes, 2) + require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), 
"previous_response_id").Exists(), "上一轮缺失 response.id 时不应自动补齐 previous_response_id") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreflightPingFailReconnectsBeforeTurn(t *testing.T) { + gin.SetMode(gin.TestMode) + prevPreflightPingIdle := openAIWSIngressPreflightPingIdle + openAIWSIngressPreflightPingIdle = 0 + defer func() { + openAIWSIngressPreflightPingIdle = prevPreflightPingIdle + }() + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + firstConn := &openAIWSPreflightFailConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + secondConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{firstConn, secondConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(dialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 116, + Name: "openai-ingress-preflight-ping", + Platform: PlatformOpenAI, + Type: 
AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + 
		require.NoError(t, readErr)
		require.Equal(t, coderws.MessageText, msgType)
		return message
	}

	writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`)
	firstTurn := readMessage()
	require.Equal(t, "resp_turn_ping_1", gjson.GetBytes(firstTurn, "response.id").String())

	writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_ping_1"}`)
	secondTurn := readMessage()
	require.Equal(t, "resp_turn_ping_2", gjson.GetBytes(secondTurn, "response.id").String())

	require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
	select {
	case serverErr := <-serverErrCh:
		require.NoError(t, serverErr)
	case <-time.After(5 * time.Second):
		t.Fatal("等待 ingress websocket 结束超时")
	}
	require.Equal(t, 2, dialer.DialCount(), "第二轮 turn 前 ping 失败应触发换连")
	require.Equal(t, 1, firstConn.WriteCount(), "preflight ping 失败后不应继续向旧连接发送第二轮 turn")
	require.GreaterOrEqual(t, firstConn.PingCount(), 1, "第二轮前应对旧连接执行 preflight ping")
}

// TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledStrictAffinityPreflightPingFailAutoRecoveryReconnects
// verifies that with store=false (strict connection affinity), a failed preflight
// ping before the second turn triggers automatic recovery: the proxy dials a new
// upstream connection and replays the full input context there, dropping
// previous_response_id, instead of writing the second turn to the stale connection.
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledStrictAffinityPreflightPingFailAutoRecoveryReconnects(t *testing.T) {
	gin.SetMode(gin.TestMode)
	// Force the preflight-ping idle threshold to zero so the ping runs before
	// every turn; restore the package-level value after the test.
	prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
	openAIWSIngressPreflightPingIdle = 0
	defer func() {
		openAIWSIngressPreflightPingIdle = prevPreflightPingIdle
	}()

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
	cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
	cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3

	// First upstream conn completes turn 1 then starts failing pings
	// (openAIWSPreflightFailConn flips pingFails after its events drain);
	// second conn serves the replayed turn 2.
	firstConn := &openAIWSPreflightFailConn{
		events: [][]byte{
			[]byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_strict_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
		},
	}
	secondConn := &openAIWSCaptureConn{
		events: [][]byte{
			[]byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_strict_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
		},
	}
	dialer := &openAIWSQueueDialer{
		conns: []openAIWSClientConn{firstConn, secondConn},
	}
	pool := newOpenAIWSConnPool(cfg)
	pool.setClientDialerForTest(dialer)

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     &httpUpstreamRecorder{},
		cache:            &stubGatewayCache{},
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
		openaiWSPool:     pool,
	}

	account := &Account{
		ID:          121,
		Name:        "openai-ingress-preflight-ping-strict-affinity",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key": "sk-test",
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	// Test websocket server: accepts the client, reads the first message, then
	// hands the connection to the service under test and reports its error.
	serverErrCh := make(chan error, 1)
	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
			CompressionMode: coderws.CompressionContextTakeover,
		})
		if err != nil {
			serverErrCh <- err
			return
		}
		defer func() {
			_ = conn.CloseNow()
		}()

		rec := httptest.NewRecorder()
		ginCtx, _ := gin.CreateTestContext(rec)
		req := r.Clone(r.Context())
		req.Header = req.Header.Clone()
		req.Header.Set("User-Agent", "unit-test-agent/1.0")
		ginCtx.Request = req

		readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
		msgType, firstMessage, readErr := conn.Read(readCtx)
		cancel()
		if readErr != nil {
			serverErrCh <- readErr
			return
		}
		if msgType != coderws.MessageText &&
			msgType != coderws.MessageBinary {
			serverErrCh <- errors.New("unsupported websocket client message type")
			return
		}

		serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
	}))
	defer wsServer.Close()

	dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
	clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
	cancelDial()
	require.NoError(t, err)
	defer func() {
		_ = clientConn.CloseNow()
	}()

	writeMessage := func(payload string) {
		writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		defer cancel()
		require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
	}
	readMessage := func() []byte {
		readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		defer cancel()
		msgType, message, readErr := clientConn.Read(readCtx)
		require.NoError(t, readErr)
		require.Equal(t, coderws.MessageText, msgType)
		return message
	}

	// Turn 1: store=false request; served by firstConn.
	writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
	firstTurn := readMessage()
	require.Equal(t, "resp_turn_ping_strict_1", gjson.GetBytes(firstTurn, "response.id").String())

	// Turn 2: chained via previous_response_id; preflight ping on firstConn
	// fails, so the proxy must recover on a fresh connection.
	writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_turn_ping_strict_1","input":[{"type":"input_text","text":"world"}]}`)
	secondTurn := readMessage()
	require.Equal(t, "resp_turn_ping_strict_2", gjson.GetBytes(secondTurn, "response.id").String())

	require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
	select {
	case serverErr := <-serverErrCh:
		require.NoError(t, serverErr)
	case <-time.After(5 * time.Second):
		t.Fatal("等待 ingress websocket 严格亲和自动恢复后结束超时")
	}

	require.Equal(t, 2, dialer.DialCount(), "严格亲和 preflight ping 失败后应自动降级并换连重放")
	require.Equal(t, 1, firstConn.WriteCount(), "preflight ping 失败后不应继续在旧连接写第二轮")
	require.GreaterOrEqual(t, firstConn.PingCount(), 1, "第二轮前应执行 preflight ping")
	secondConn.mu.Lock()
	secondWrites := append([]map[string]any(nil), secondConn.writes...)
	secondConn.mu.Unlock()
	require.Len(t, secondWrites, 1)
	secondWrite := requestToJSONString(secondWrites[0])
	require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "自动恢复重放应移除 previous_response_id")
	require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array()), "自动恢复重放应使用完整 input 上下文")
	require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String())
	require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String())
}

// TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_WriteFailBeforeDownstreamRetriesOnce
// verifies that when the upstream write for a turn fails before anything has
// been written downstream, the proxy retries exactly once on a fresh
// connection, without re-firing the BeforeTurn hook for the retried turn.
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_WriteFailBeforeDownstreamRetriesOnce(t *testing.T) {
	gin.SetMode(gin.TestMode)

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
	cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
	cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3

	// firstConn serves turn 1 then rejects further writes
	// (openAIWSWriteFailAfterFirstTurnConn flips failOnWrite once drained).
	firstConn := &openAIWSWriteFailAfterFirstTurnConn{
		events: [][]byte{
			[]byte(`{"type":"response.completed","response":{"id":"resp_turn_write_retry_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
		},
	}
	secondConn := &openAIWSCaptureConn{
		events: [][]byte{
			[]byte(`{"type":"response.completed","response":{"id":"resp_turn_write_retry_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
		},
	}
	dialer := &openAIWSQueueDialer{
		conns: []openAIWSClientConn{firstConn,
			secondConn},
	}
	pool := newOpenAIWSConnPool(cfg)
	pool.setClientDialerForTest(dialer)

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     &httpUpstreamRecorder{},
		cache:            &stubGatewayCache{},
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
		openaiWSPool:     pool,
	}

	account := &Account{
		ID:          117,
		Name:        "openai-ingress-write-retry",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key": "sk-test",
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}
	// Count BeforeTurn/AfterTurn invocations per turn number to prove the
	// retry does not re-fire hooks for the same turn.
	var hooksMu sync.Mutex
	beforeTurnCalls := make(map[int]int)
	afterTurnCalls := make(map[int]int)
	hooks := &OpenAIWSIngressHooks{
		BeforeTurn: func(turn int) error {
			hooksMu.Lock()
			beforeTurnCalls[turn]++
			hooksMu.Unlock()
			return nil
		},
		AfterTurn: func(turn int, _ *OpenAIForwardResult, _ error) {
			hooksMu.Lock()
			afterTurnCalls[turn]++
			hooksMu.Unlock()
		},
	}

	// Test websocket server: accepts the client, reads the first message, then
	// hands the connection to the service under test (with hooks) and reports
	// its error.
	serverErrCh := make(chan error, 1)
	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
			CompressionMode: coderws.CompressionContextTakeover,
		})
		if err != nil {
			serverErrCh <- err
			return
		}
		defer func() {
			_ = conn.CloseNow()
		}()

		rec := httptest.NewRecorder()
		ginCtx, _ := gin.CreateTestContext(rec)
		req := r.Clone(r.Context())
		req.Header = req.Header.Clone()
		req.Header.Set("User-Agent", "unit-test-agent/1.0")
		ginCtx.Request = req

		readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
		msgType, firstMessage, readErr := conn.Read(readCtx)
		cancel()
		if readErr != nil {
			serverErrCh <- readErr
			return
		}
		if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
			serverErrCh <- errors.New("unsupported websocket client message type")
			return
		}

		serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks)
	}))
	defer wsServer.Close()

	dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
	clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
	cancelDial()
	require.NoError(t, err)
	defer func() {
		_ = clientConn.CloseNow()
	}()

	writeMessage := func(payload string) {
		writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		defer cancel()
		require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
	}
	readMessage := func() []byte {
		readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		defer cancel()
		msgType, message, readErr := clientConn.Read(readCtx)
		require.NoError(t, readErr)
		require.Equal(t, coderws.MessageText, msgType)
		return message
	}

	// Turn 1 succeeds on firstConn.
	writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`)
	firstTurn := readMessage()
	require.Equal(t, "resp_turn_write_retry_1", gjson.GetBytes(firstTurn, "response.id").String())

	// Turn 2: the write to firstConn fails, nothing was sent downstream yet,
	// so the proxy retries once on secondConn.
	writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_write_retry_1"}`)
	secondTurn := readMessage()
	require.Equal(t, "resp_turn_write_retry_2", gjson.GetBytes(secondTurn, "response.id").String())

	require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
	select {
	case serverErr := <-serverErrCh:
		require.NoError(t, serverErr)
	case <-time.After(5 * time.Second):
		t.Fatal("等待 ingress websocket 结束超时")
	}
	require.Equal(t, 2, dialer.DialCount(), "第二轮 turn 上游写失败且未写下游时应自动重试并换连")
	hooksMu.Lock()
	beforeTurn1 := beforeTurnCalls[1]
	beforeTurn2 := beforeTurnCalls[2]
	afterTurn1 := afterTurnCalls[1]
	afterTurn2 := afterTurnCalls[2]
	hooksMu.Unlock()
	require.Equal(t, 1, beforeTurn1, "首轮 turn BeforeTurn 应执行一次")
	require.Equal(t, 1, beforeTurn2, "同一 turn 重试不应重复触发 BeforeTurn")
	require.Equal(t, 1, afterTurn1, "首轮 turn AfterTurn 应执行一次")
	require.Equal(t, 1, afterTurn2, "第二轮 turn AfterTurn 应执行一次")
}

// TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreviousResponseNotFoundRecoversByDroppingPrevID
// verifies that when the upstream answers a chained turn with
// previous_response_not_found and recovery is enabled, the proxy retries once
// on a new connection with previous_response_id removed.
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreviousResponseNotFoundRecoversByDroppingPrevID(t *testing.T) {
	gin.SetMode(gin.TestMode)

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
	cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
	cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3

	// firstConn answers turn 1, then emits previous_response_not_found for
	// turn 2; secondConn serves the recovery retry.
	firstConn := &openAIWSCaptureConn{
		events: [][]byte{
			[]byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_recover_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
			[]byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":""}}`),
		},
	}
	secondConn := &openAIWSCaptureConn{
		events: [][]byte{
			[]byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_recover_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
		},
	}
	dialer := &openAIWSQueueDialer{
		conns: []openAIWSClientConn{firstConn, secondConn},
	}
	pool := newOpenAIWSConnPool(cfg)
	pool.setClientDialerForTest(dialer)

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     &httpUpstreamRecorder{},
		cache:            &stubGatewayCache{},
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
		openaiWSPool:     pool,
	}

	account := &Account{
		ID:   118,
		Name:
		"openai-ingress-prev-recovery",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key": "sk-test",
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	// Test websocket server: accepts the client, reads the first message, then
	// hands the connection to the service under test and reports its error.
	serverErrCh := make(chan error, 1)
	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
			CompressionMode: coderws.CompressionContextTakeover,
		})
		if err != nil {
			serverErrCh <- err
			return
		}
		defer func() {
			_ = conn.CloseNow()
		}()

		rec := httptest.NewRecorder()
		ginCtx, _ := gin.CreateTestContext(rec)
		req := r.Clone(r.Context())
		req.Header = req.Header.Clone()
		req.Header.Set("User-Agent", "unit-test-agent/1.0")
		ginCtx.Request = req

		readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
		msgType, firstMessage, readErr := conn.Read(readCtx)
		cancel()
		if readErr != nil {
			serverErrCh <- readErr
			return
		}
		if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
			serverErrCh <- errors.New("unsupported websocket client message type")
			return
		}

		serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
	}))
	defer wsServer.Close()

	dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
	clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
	cancelDial()
	require.NoError(t, err)
	defer func() {
		_ = clientConn.CloseNow()
	}()

	writeMessage := func(payload string) {
		writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		defer cancel()
		require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
	}
	readMessage := func() []byte {
		readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		defer cancel()
		msgType, message, readErr := clientConn.Read(readCtx)
		require.NoError(t, readErr)
		require.Equal(t, coderws.MessageText, msgType)
		return message
	}

	// Turn 1 chains to a seed anchor id and succeeds on firstConn.
	writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_seed_anchor"}`)
	firstTurn := readMessage()
	require.Equal(t, "resp_turn_prev_recover_1", gjson.GetBytes(firstTurn, "response.id").String())

	// Turn 2 hits previous_response_not_found; recovery retries on secondConn
	// and the client only ever sees the successful completed event.
	writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_prev_recover_1"}`)
	secondTurn := readMessage()
	require.Equal(t, "response.completed", gjson.GetBytes(secondTurn, "type").String())
	require.Equal(t, "resp_turn_prev_recover_2", gjson.GetBytes(secondTurn, "response.id").String())

	require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
	select {
	case serverErr := <-serverErrCh:
		require.NoError(t, serverErr)
	case <-time.After(5 * time.Second):
		t.Fatal("等待 ingress websocket 结束超时")
	}

	require.Equal(t, 2, dialer.DialCount(), "previous_response_not_found 恢复应触发换连重试")

	firstConn.mu.Lock()
	firstWrites := append([]map[string]any(nil), firstConn.writes...)
	firstConn.mu.Unlock()
	require.Len(t, firstWrites, 2, "首个连接应处理首轮与失败的第二轮请求")
	require.True(t, gjson.Get(requestToJSONString(firstWrites[1]), "previous_response_id").Exists(), "失败轮次首发请求应包含 previous_response_id")

	secondConn.mu.Lock()
	secondWrites := append([]map[string]any(nil), secondConn.writes...)
	secondConn.mu.Unlock()
	require.Len(t, secondWrites, 1, "恢复重试应在第二个连接发送一次请求")
	require.False(t, gjson.Get(requestToJSONString(secondWrites[0]), "previous_response_id").Exists(), "恢复重试应移除 previous_response_id")
}

// TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledStrictAffinityPreviousResponseNotFoundLayer2Recovery
// verifies the second recovery layer on the store=false (strict affinity)
// path: when a chained turn fails with previous_response_not_found, the proxy
// replays once on a new connection with the full input context, dropping
// previous_response_id while leaving store=false untouched.
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledStrictAffinityPreviousResponseNotFoundLayer2Recovery(t *testing.T) {
	gin.SetMode(gin.TestMode)

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
	cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
	cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3

	// firstConn answers turn 1 then emits previous_response_not_found;
	// secondConn serves the Layer2 replay.
	firstConn := &openAIWSCaptureConn{
		events: [][]byte{
			[]byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_strict_recover_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
			[]byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":"missing strict anchor"}}`),
		},
	}
	secondConn := &openAIWSCaptureConn{
		events: [][]byte{
			[]byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_strict_recover_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
		},
	}
	dialer := &openAIWSQueueDialer{
		conns: []openAIWSClientConn{firstConn, secondConn},
	}
	pool := newOpenAIWSConnPool(cfg)
	pool.setClientDialerForTest(dialer)

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     &httpUpstreamRecorder{},
		cache:            &stubGatewayCache{},
		openaiWSResolver:
			NewOpenAIWSProtocolResolver(cfg),
		toolCorrector: NewCodexToolCorrector(),
		openaiWSPool:  pool,
	}

	account := &Account{
		ID:          122,
		Name:        "openai-ingress-prev-strict-layer2",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key": "sk-test",
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	// Test websocket server: accepts the client, reads the first message, then
	// hands the connection to the service under test and reports its error.
	serverErrCh := make(chan error, 1)
	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
			CompressionMode: coderws.CompressionContextTakeover,
		})
		if err != nil {
			serverErrCh <- err
			return
		}
		defer func() {
			_ = conn.CloseNow()
		}()

		rec := httptest.NewRecorder()
		ginCtx, _ := gin.CreateTestContext(rec)
		req := r.Clone(r.Context())
		req.Header = req.Header.Clone()
		req.Header.Set("User-Agent", "unit-test-agent/1.0")
		ginCtx.Request = req

		readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
		msgType, firstMessage, readErr := conn.Read(readCtx)
		cancel()
		if readErr != nil {
			serverErrCh <- readErr
			return
		}
		if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
			serverErrCh <- errors.New("unsupported websocket client message type")
			return
		}

		serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
	}))
	defer wsServer.Close()

	dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
	clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
	cancelDial()
	require.NoError(t, err)
	defer func() {
		_ = clientConn.CloseNow()
	}()

	writeMessage := func(payload string) {
		writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		defer cancel()
		require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
	}
	readMessage := func() []byte {
		readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		defer cancel()
		msgType, message, readErr := clientConn.Read(readCtx)
		require.NoError(t, readErr)
		require.Equal(t, coderws.MessageText, msgType)
		return message
	}

	// Turn 1: store=false with a prompt_cache_key; succeeds on firstConn.
	writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"prompt_cache_key":"pk_strict_layer2","input":[{"type":"input_text","text":"hello"}]}`)
	firstTurn := readMessage()
	require.Equal(t, "resp_turn_prev_strict_recover_1", gjson.GetBytes(firstTurn, "response.id").String())

	// Turn 2: chained turn fails with previous_response_not_found; Layer2
	// recovery replays it on secondConn.
	writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"prompt_cache_key":"pk_strict_layer2","previous_response_id":"resp_turn_prev_strict_recover_1","input":[{"type":"input_text","text":"world"}]}`)
	secondTurn := readMessage()
	require.Equal(t, "resp_turn_prev_strict_recover_2", gjson.GetBytes(secondTurn, "response.id").String())

	require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
	select {
	case serverErr := <-serverErrCh:
		require.NoError(t, serverErr)
	case <-time.After(5 * time.Second):
		t.Fatal("等待 ingress websocket 严格亲和 Layer2 恢复结束超时")
	}

	require.Equal(t, 2, dialer.DialCount(), "严格亲和链路命中 previous_response_not_found 应触发 Layer2 恢复重试")

	firstConn.mu.Lock()
	firstWrites := append([]map[string]any(nil), firstConn.writes...)
	firstConn.mu.Unlock()
	require.Len(t, firstWrites, 2, "首连接应收到首轮请求和失败的续链请求")
	require.True(t, gjson.Get(requestToJSONString(firstWrites[1]), "previous_response_id").Exists())

	secondConn.mu.Lock()
	secondWrites := append([]map[string]any(nil), secondConn.writes...)
	secondConn.mu.Unlock()
	require.Len(t, secondWrites, 1, "Layer2 恢复应仅重放一次")
	secondWrite := requestToJSONString(secondWrites[0])
	require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "Layer2 恢复重放应移除 previous_response_id")
	require.True(t, gjson.Get(secondWrite, "store").Exists(), "Layer2 恢复不应改变 store 标志")
	require.False(t, gjson.Get(secondWrite, "store").Bool())
	require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array()), "Layer2 恢复应重放完整 input 上下文")
	require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String())
	require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String())
}

// TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreviousResponseNotFoundRecoveryRemovesDuplicatePrevID
// verifies that recovery from previous_response_not_found strips ALL
// occurrences of previous_response_id from the replayed request, even when the
// client JSON contains the key twice.
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreviousResponseNotFoundRecoveryRemovesDuplicatePrevID(t *testing.T) {
	gin.SetMode(gin.TestMode)

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
	cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
	cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3

	// firstConn answers turn 1 then emits previous_response_not_found;
	// secondConn serves the single recovery retry.
	firstConn := &openAIWSCaptureConn{
		events: [][]byte{
			[]byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_once_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
			[]byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":"first missing"}}`),
		},
	}
	secondConn := &openAIWSCaptureConn{
		events: [][]byte{
			[]byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_once_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
		},
	}
	dialer := &openAIWSQueueDialer{
		conns: []openAIWSClientConn{firstConn, secondConn},
	}
	pool := newOpenAIWSConnPool(cfg)
	pool.setClientDialerForTest(dialer)

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     &httpUpstreamRecorder{},
		cache:            &stubGatewayCache{},
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
		openaiWSPool:     pool,
	}

	account := &Account{
		ID:          120,
		Name:        "openai-ingress-prev-recovery-once",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key": "sk-test",
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	// Test websocket server: accepts the client, reads the first message, then
	// hands the connection to the service under test and reports its error.
	serverErrCh := make(chan error, 1)
	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
			CompressionMode: coderws.CompressionContextTakeover,
		})
		if err != nil {
			serverErrCh <- err
			return
		}
		defer func() {
			_ = conn.CloseNow()
		}()

		rec := httptest.NewRecorder()
		ginCtx, _ := gin.CreateTestContext(rec)
		req := r.Clone(r.Context())
		req.Header = req.Header.Clone()
		req.Header.Set("User-Agent", "unit-test-agent/1.0")
		ginCtx.Request = req

		readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
		msgType, firstMessage, readErr := conn.Read(readCtx)
		cancel()
		if readErr != nil {
			serverErrCh <- readErr
			return
		}
		if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
			serverErrCh <- errors.New("unsupported websocket client message type")
			return
		}

		serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
	}))
	defer wsServer.Close()

	dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
	clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
	cancelDial()
	require.NoError(t, err)
	defer func() {
		_ = clientConn.CloseNow()
	}()

	writeMessage := func(payload string) {
		writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		defer cancel()
		require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
	}
	readMessage := func() []byte {
		readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		defer cancel()
		msgType, message, readErr := clientConn.Read(readCtx)
		require.NoError(t, readErr)
		require.Equal(t, coderws.MessageText, msgType)
		return message
	}

	writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`)
	firstTurn := readMessage()
	require.Equal(t, "resp_turn_prev_once_1", gjson.GetBytes(firstTurn, "response.id").String())

	// Duplicate previous_response_id keys: the recovery retry must remove every
	// occurrence so the replay cannot hit previous_response_not_found again.
	writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_prev_once_1","input":[],"previous_response_id":"resp_turn_prev_duplicate"}`)
	secondTurn := readMessage()
	require.Equal(t, "resp_turn_prev_once_2", gjson.GetBytes(secondTurn, "response.id").String())

	require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
	select {
	case serverErr := <-serverErrCh:
		require.NoError(t, serverErr)
	case <-time.After(5 * time.Second):
		t.Fatal("等待 ingress websocket 结束超时")
	}

	require.Equal(t, 2, dialer.DialCount(), "previous_response_not_found 恢复应只重试一次")

	firstConn.mu.Lock()
	firstWrites := append([]map[string]any(nil), firstConn.writes...)
	firstConn.mu.Unlock()
	require.Len(t, firstWrites, 2)
	require.True(t, gjson.Get(requestToJSONString(firstWrites[1]), "previous_response_id").Exists())

	secondConn.mu.Lock()
	secondWrites := append([]map[string]any(nil), secondConn.writes...)
	secondConn.mu.Unlock()
	require.Len(t, secondWrites, 1)
	require.False(t, gjson.Get(requestToJSONString(secondWrites[0]), "previous_response_id").Exists(), "重复键场景恢复重试后不应保留 previous_response_id")
}

// TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_RejectsMessageIDAsPreviousResponseID
// verifies that a previous_response_id carrying a message id ("msg_…") instead
// of a response id is rejected with a policy-violation websocket close before
// any upstream dial happens (note: no pool/dialer is configured here).
func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
	gin.SetMode(gin.TestMode)

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     &httpUpstreamRecorder{},
		cache:            &stubGatewayCache{},
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
	}

	account := &Account{
		ID:          119,
		Name:        "openai-ingress-prev-validation",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency: 1,
		Credentials: map[string]any{
			"api_key": "sk-test",
		},
		Extra: map[string]any{
			"responses_websockets_v2_enabled": true,
		},
	}

	serverErrCh := make(chan error, 1)
	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
			CompressionMode: coderws.CompressionContextTakeover,
		})
		if err != nil {
			serverErrCh <- err
			return
		}
		defer func() {
			_ = conn.CloseNow()
		}()

		rec := httptest.NewRecorder()
		ginCtx, _ := gin.CreateTestContext(rec)
		req := r.Clone(r.Context())
		req.Header = req.Header.Clone()
		req.Header.Set("User-Agent", "unit-test-agent/1.0")
		ginCtx.Request = req

		readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
		msgType, firstMessage, readErr := conn.Read(readCtx)
		cancel()
		if readErr != nil {
			serverErrCh <- readErr
			return
		}
		if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
			serverErrCh <- errors.New("unsupported websocket client message type")
			return
		}

		serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
	}))
	defer wsServer.Close()

	dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
	clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
	cancelDial()
	require.NoError(t, err)
	defer func() {
		_ = clientConn.CloseNow()
	}()

	// Send a previous_response_id that is a message id ("msg_…"); the service
	// must reject it with a policy-violation close instead of proxying.
	writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
	err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_123456"}`))
	cancelWrite()
	require.NoError(t, err)

	select {
	case serverErr := <-serverErrCh:
		require.Error(t, serverErr)
		var closeErr *OpenAIWSClientCloseError
		require.ErrorAs(t, serverErr, &closeErr)
		require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
		require.Contains(t, closeErr.Reason(), "previous_response_id must be a response.id")
	case <-time.After(5 * time.Second):
		t.Fatal("等待 ingress websocket 结束超时")
	}
}

// openAIWSQueueDialer is a test dialer that hands out a fixed queue of fake
// upstream connections and counts Dial calls. The last connection in the queue
// is reused for any further dials.
type openAIWSQueueDialer struct {
	mu        sync.Mutex // guards conns and dialCount
	conns     []openAIWSClientConn
	dialCount int
}

// Dial returns the next queued connection; when the queue is empty it fails
// with a 503-style error. All dial parameters are ignored.
func (d *openAIWSQueueDialer) Dial(
	ctx context.Context,
	wsURL string,
	headers http.Header,
	proxyURL string,
) (openAIWSClientConn, int, http.Header, error) {
	_ = ctx
	_ = wsURL
	_ = headers
	_ = proxyURL
	d.mu.Lock()
	defer d.mu.Unlock()
	d.dialCount++
	if len(d.conns) == 0 {
		return nil, 503, nil, errors.New("no test conn")
	}
	conn := d.conns[0]
	// Keep the last connection in place so repeated dials still succeed.
	if len(d.conns) > 1 {
		d.conns = d.conns[1:]
	}
	return conn, 0, nil, nil
}

// DialCount reports how many times Dial has been invoked.
func (d *openAIWSQueueDialer) DialCount() int {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.dialCount
}

// openAIWSPreflightFailConn is a fake upstream connection whose Ping starts
// failing once all queued events have been read, simulating a connection that
// dies between turns and is caught by the preflight ping.
type openAIWSPreflightFailConn struct {
	mu         sync.Mutex // guards all fields below
	events     [][]byte
	pingFails  bool // set once events drain; subsequent Pings fail
	writeCount int
	pingCount  int
}

// WriteJSON records the write and always succeeds.
func (c *openAIWSPreflightFailConn) WriteJSON(context.Context, any) error {
	c.mu.Lock()
	c.writeCount++
	c.mu.Unlock()
	return nil
}

// ReadMessage pops the next queued event; after the last event it arms
// pingFails, and once empty it returns io.EOF.
func (c *openAIWSPreflightFailConn) ReadMessage(context.Context) ([]byte, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if len(c.events) == 0 {
		return nil, io.EOF
	}
	event := c.events[0]
	c.events = c.events[1:]
	if len(c.events) == 0 {
		c.pingFails = true
	}
	return event, nil
}

// Ping succeeds until the event queue has drained, then always fails.
func (c *openAIWSPreflightFailConn) Ping(context.Context) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.pingCount++
	if c.pingFails {
		return errors.New("preflight ping failed")
	}
	return nil
}

func (c *openAIWSPreflightFailConn) Close() error {
	return nil
}

// WriteCount reports how many WriteJSON calls were made.
func (c *openAIWSPreflightFailConn) WriteCount() int {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.writeCount
}

// PingCount reports how many Ping calls were made.
func (c *openAIWSPreflightFailConn) PingCount() int {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.pingCount
}

// openAIWSWriteFailAfterFirstTurnConn is a fake upstream connection that
// serves its queued events for the first turn, then rejects every subsequent
// WriteJSON — simulating a stale connection detected only at write time.
type openAIWSWriteFailAfterFirstTurnConn struct {
	mu          sync.Mutex // guards events and failOnWrite
	events      [][]byte
	failOnWrite bool // set once events drain; subsequent writes fail
}

// WriteJSON succeeds until the event queue has drained, then always fails.
func (c *openAIWSWriteFailAfterFirstTurnConn) WriteJSON(context.Context, any) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.failOnWrite {
		return errors.New("write failed on stale conn")
	}
	return nil
}

// ReadMessage pops the next queued event; after the last event it arms
// failOnWrite, and once empty it returns io.EOF.
func (c *openAIWSWriteFailAfterFirstTurnConn) ReadMessage(context.Context) ([]byte, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if len(c.events) == 0 {
		return nil, io.EOF
	}
	event := c.events[0]
	c.events = c.events[1:]
	if len(c.events) == 0 {
		c.failOnWrite = true
	}
	return event, nil
}

func (c *openAIWSWriteFailAfterFirstTurnConn) Ping(context.Context) error {
	return nil
}

func (c *openAIWSWriteFailAfterFirstTurnConn) Close() error {
	return nil
}

func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ClientDisconnectStillDrainsUpstream(t *testing.T) {
	gin.SetMode(gin.TestMode)

	cfg := &config.Config{}
	cfg.Security.URLAllowlist.Enabled = false
	cfg.Security.URLAllowlist.AllowInsecureHTTP = true
	cfg.Gateway.OpenAIWS.Enabled = true
	cfg.Gateway.OpenAIWS.OAuthEnabled = true
	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
	cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
	cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
	cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
	cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
	cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
	cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3

	// Several upstream events: the leading ones are non-terminal, the last one
	// is terminal. The first event is delayed 250ms so the client RST has time
	// to propagate, making writeClientMessage fail reliably.
	captureConn := &openAIWSCaptureConn{
		readDelays: []time.Duration{250 * time.Millisecond, 0, 0},
		events: [][]byte{
			[]byte(`{"type":"response.created","response":{"id":"resp_ingress_disconnect","model":"gpt-5.1"}}`),
			[]byte(`{"type":"response.output_item.added","response":{"id":"resp_ingress_disconnect"}}`),
			[]byte(`{"type":"response.completed","response":{"id":"resp_ingress_disconnect","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`),
		},
	}
	captureDialer := &openAIWSCaptureDialer{conn: captureConn}
	pool := newOpenAIWSConnPool(cfg)
	pool.setClientDialerForTest(captureDialer)

	svc := &OpenAIGatewayService{
		cfg:              cfg,
		httpUpstream:     &httpUpstreamRecorder{},
		cache:            &stubGatewayCache{},
		openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
		toolCorrector:    NewCodexToolCorrector(),
		openaiWSPool:     pool,
	}

	account := &Account{
		ID:          115,
		Name:        "openai-ingress-client-disconnect",
		Platform:    PlatformOpenAI,
		Type:        AccountTypeAPIKey,
		Status:      StatusActive,
		Schedulable: true,
		Concurrency:
1, + Credentials: map[string]any{ + "api_key": "sk-test", + "model_mapping": map[string]any{ + "custom-original-model": "gpt-5.1", + }, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + resultCh := make(chan *OpenAIForwardResult, 1) + hooks := &OpenAIWSIngressHooks{ + AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) { + if turnErr == nil && result != nil { + resultCh <- result + } + }, + } + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"custom-original-model","stream":false}`)) + cancelWrite() + 
require.NoError(t, err) + // 立即关闭客户端,模拟客户端在 relay 期间断连。 + require.NoError(t, clientConn.CloseNow(), "模拟 ingress 客户端提前断连") + + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr, "客户端断连后应继续 drain 上游直到 terminal 或正常结束") + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + select { + case result := <-resultCh: + require.Equal(t, "resp_ingress_disconnect", result.RequestID) + require.Equal(t, 2, result.Usage.InputTokens) + require.Equal(t, 1, result.Usage.OutputTokens) + case <-time.After(2 * time.Second): + t.Fatal("未收到断连后的 turn 结果回调") + } +} diff --git a/backend/internal/service/openai_ws_forwarder_ingress_test.go b/backend/internal/service/openai_ws_forwarder_ingress_test.go new file mode 100644 index 00000000..ff35cb01 --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_ingress_test.go @@ -0,0 +1,714 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "io" + "net" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + coderws "github.com/coder/websocket" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestIsOpenAIWSClientDisconnectError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + want bool + }{ + {name: "nil", err: nil, want: false}, + {name: "io_eof", err: io.EOF, want: true}, + {name: "net_closed", err: net.ErrClosed, want: true}, + {name: "context_canceled", err: context.Canceled, want: true}, + {name: "ws_normal_closure", err: coderws.CloseError{Code: coderws.StatusNormalClosure}, want: true}, + {name: "ws_going_away", err: coderws.CloseError{Code: coderws.StatusGoingAway}, want: true}, + {name: "ws_no_status", err: coderws.CloseError{Code: coderws.StatusNoStatusRcvd}, want: true}, + {name: "ws_abnormal_1006", err: coderws.CloseError{Code: coderws.StatusAbnormalClosure}, want: true}, + {name: "ws_policy_violation", err: coderws.CloseError{Code: coderws.StatusPolicyViolation}, want: false}, + 
{name: "wrapped_eof_message", err: errors.New("failed to get reader: failed to read frame header: EOF"), want: true}, + {name: "connection_reset_by_peer", err: errors.New("failed to read frame header: read tcp 127.0.0.1:1234->127.0.0.1:5678: read: connection reset by peer"), want: true}, + {name: "broken_pipe", err: errors.New("write tcp 127.0.0.1:1234->127.0.0.1:5678: write: broken pipe"), want: true}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, isOpenAIWSClientDisconnectError(tt.err)) + }) + } +} + +func TestIsOpenAIWSIngressPreviousResponseNotFound(t *testing.T) { + t.Parallel() + + require.False(t, isOpenAIWSIngressPreviousResponseNotFound(nil)) + require.False(t, isOpenAIWSIngressPreviousResponseNotFound(errors.New("plain error"))) + require.False(t, isOpenAIWSIngressPreviousResponseNotFound( + wrapOpenAIWSIngressTurnError("read_upstream", errors.New("upstream read failed"), false), + )) + require.False(t, isOpenAIWSIngressPreviousResponseNotFound( + wrapOpenAIWSIngressTurnError(openAIWSIngressStagePreviousResponseNotFound, errors.New("previous response not found"), true), + )) + require.True(t, isOpenAIWSIngressPreviousResponseNotFound( + wrapOpenAIWSIngressTurnError(openAIWSIngressStagePreviousResponseNotFound, errors.New("previous response not found"), false), + )) +} + +func TestOpenAIWSIngressPreviousResponseRecoveryEnabled(t *testing.T) { + t.Parallel() + + var nilService *OpenAIGatewayService + require.True(t, nilService.openAIWSIngressPreviousResponseRecoveryEnabled(), "nil service should default to enabled") + + svcWithNilCfg := &OpenAIGatewayService{} + require.True(t, svcWithNilCfg.openAIWSIngressPreviousResponseRecoveryEnabled(), "nil config should default to enabled") + + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + } + require.False(t, svc.openAIWSIngressPreviousResponseRecoveryEnabled(), "explicit config default should be false") + + 
svc.cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true + require.True(t, svc.openAIWSIngressPreviousResponseRecoveryEnabled()) +} + +func TestDropPreviousResponseIDFromRawPayload(t *testing.T) { + t.Parallel() + + t.Run("empty_payload", func(t *testing.T) { + updated, removed, err := dropPreviousResponseIDFromRawPayload(nil) + require.NoError(t, err) + require.False(t, removed) + require.Empty(t, updated) + }) + + t.Run("payload_without_previous_response_id", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`) + updated, removed, err := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, err) + require.False(t, removed) + require.Equal(t, string(payload), string(updated)) + }) + + t.Run("normal_delete_success", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`) + updated, removed, err := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, err) + require.True(t, removed) + require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists()) + }) + + t.Run("duplicate_keys_are_removed", func(t *testing.T) { + payload := []byte(`{"type":"response.create","previous_response_id":"resp_a","input":[],"previous_response_id":"resp_b"}`) + updated, removed, err := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, err) + require.True(t, removed) + require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists()) + }) + + t.Run("nil_delete_fn_uses_default_delete_logic", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`) + updated, removed, err := dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, nil) + require.NoError(t, err) + require.True(t, removed) + require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists()) + }) + + t.Run("delete_error", func(t *testing.T) { + payload := 
[]byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`) + updated, removed, err := dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, func(_ []byte, _ string) ([]byte, error) { + return nil, errors.New("delete failed") + }) + require.Error(t, err) + require.False(t, removed) + require.Equal(t, string(payload), string(updated)) + }) + + t.Run("malformed_json_is_still_best_effort_deleted", func(t *testing.T) { + payload := []byte(`{"type":"response.create","previous_response_id":"resp_abc"`) + require.True(t, gjson.GetBytes(payload, "previous_response_id").Exists()) + + updated, removed, err := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, err) + require.True(t, removed) + require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists()) + }) +} + +func TestAlignStoreDisabledPreviousResponseID(t *testing.T) { + t.Parallel() + + t.Run("empty_payload", func(t *testing.T) { + updated, changed, err := alignStoreDisabledPreviousResponseID(nil, "resp_target") + require.NoError(t, err) + require.False(t, changed) + require.Empty(t, updated) + }) + + t.Run("empty_expected", func(t *testing.T) { + payload := []byte(`{"type":"response.create","previous_response_id":"resp_old"}`) + updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "") + require.NoError(t, err) + require.False(t, changed) + require.Equal(t, string(payload), string(updated)) + }) + + t.Run("missing_previous_response_id", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`) + updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target") + require.NoError(t, err) + require.False(t, changed) + require.Equal(t, string(payload), string(updated)) + }) + + t.Run("already_aligned", func(t *testing.T) { + payload := []byte(`{"type":"response.create","previous_response_id":"resp_target"}`) + updated, changed, err := alignStoreDisabledPreviousResponseID(payload, 
"resp_target") + require.NoError(t, err) + require.False(t, changed) + require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String()) + }) + + t.Run("mismatch_rewrites_to_expected", func(t *testing.T) { + payload := []byte(`{"type":"response.create","previous_response_id":"resp_old","input":[]}`) + updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target") + require.NoError(t, err) + require.True(t, changed) + require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String()) + }) + + t.Run("duplicate_keys_rewrites_to_single_expected", func(t *testing.T) { + payload := []byte(`{"type":"response.create","previous_response_id":"resp_old_1","input":[],"previous_response_id":"resp_old_2"}`) + updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target") + require.NoError(t, err) + require.True(t, changed) + require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String()) + }) +} + +func TestSetPreviousResponseIDToRawPayload(t *testing.T) { + t.Parallel() + + t.Run("empty_payload", func(t *testing.T) { + updated, err := setPreviousResponseIDToRawPayload(nil, "resp_target") + require.NoError(t, err) + require.Empty(t, updated) + }) + + t.Run("empty_previous_response_id", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`) + updated, err := setPreviousResponseIDToRawPayload(payload, "") + require.NoError(t, err) + require.Equal(t, string(payload), string(updated)) + }) + + t.Run("set_previous_response_id_when_missing", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`) + updated, err := setPreviousResponseIDToRawPayload(payload, "resp_target") + require.NoError(t, err) + require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String()) + require.Equal(t, "gpt-5.1", gjson.GetBytes(updated, "model").String()) + }) + + 
t.Run("overwrite_existing_previous_response_id", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_old"}`) + updated, err := setPreviousResponseIDToRawPayload(payload, "resp_new") + require.NoError(t, err) + require.Equal(t, "resp_new", gjson.GetBytes(updated, "previous_response_id").String()) + }) +} + +func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + storeDisabled bool + turn int + hasFunctionCallOutput bool + currentPreviousResponse string + expectedPrevious string + want bool + }{ + { + name: "infer_when_all_conditions_match", + storeDisabled: true, + turn: 2, + hasFunctionCallOutput: true, + expectedPrevious: "resp_1", + want: true, + }, + { + name: "skip_when_store_enabled", + storeDisabled: false, + turn: 2, + hasFunctionCallOutput: true, + expectedPrevious: "resp_1", + want: false, + }, + { + name: "skip_on_first_turn", + storeDisabled: true, + turn: 1, + hasFunctionCallOutput: true, + expectedPrevious: "resp_1", + want: false, + }, + { + name: "skip_without_function_call_output", + storeDisabled: true, + turn: 2, + hasFunctionCallOutput: false, + expectedPrevious: "resp_1", + want: false, + }, + { + name: "skip_when_request_already_has_previous_response_id", + storeDisabled: true, + turn: 2, + hasFunctionCallOutput: true, + currentPreviousResponse: "resp_client", + expectedPrevious: "resp_1", + want: false, + }, + { + name: "skip_when_last_turn_response_id_missing", + storeDisabled: true, + turn: 2, + hasFunctionCallOutput: true, + expectedPrevious: "", + want: false, + }, + { + name: "trim_whitespace_before_judgement", + storeDisabled: true, + turn: 2, + hasFunctionCallOutput: true, + expectedPrevious: " resp_2 ", + want: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := shouldInferIngressFunctionCallOutputPreviousResponseID( + 
tt.storeDisabled, + tt.turn, + tt.hasFunctionCallOutput, + tt.currentPreviousResponse, + tt.expectedPrevious, + ) + require.Equal(t, tt.want, got) + }) + } +} + +func TestOpenAIWSInputIsPrefixExtended(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + previous []byte + current []byte + want bool + expectErr bool + }{ + { + name: "both_missing_input", + previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`), + current: []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_1"}`), + want: true, + }, + { + name: "previous_missing_current_empty_array", + previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`), + current: []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`), + want: true, + }, + { + name: "previous_missing_current_non_empty_array", + previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`), + current: []byte(`{"type":"response.create","model":"gpt-5.1","input":[{"type":"input_text","text":"hello"}]}`), + want: false, + }, + { + name: "array_prefix_match", + previous: []byte(`{"input":[{"type":"input_text","text":"hello"}]}`), + current: []byte(`{"input":[{"text":"hello","type":"input_text"},{"type":"input_text","text":"world"}]}`), + want: true, + }, + { + name: "array_prefix_mismatch", + previous: []byte(`{"input":[{"type":"input_text","text":"hello"}]}`), + current: []byte(`{"input":[{"type":"input_text","text":"different"}]}`), + want: false, + }, + { + name: "current_shorter_than_previous", + previous: []byte(`{"input":[{"type":"input_text","text":"a"},{"type":"input_text","text":"b"}]}`), + current: []byte(`{"input":[{"type":"input_text","text":"a"}]}`), + want: false, + }, + { + name: "previous_has_input_current_missing", + previous: []byte(`{"input":[{"type":"input_text","text":"a"}]}`), + current: []byte(`{"model":"gpt-5.1"}`), + want: false, + }, + { + name: "input_string_treated_as_single_item", + previous: []byte(`{"input":"hello"}`), + current: 
[]byte(`{"input":"hello"}`), + want: true, + }, + { + name: "current_invalid_input_json", + previous: []byte(`{"input":[]}`), + current: []byte(`{"input":[}`), + expectErr: true, + }, + { + name: "invalid_input_json", + previous: []byte(`{"input":[}`), + current: []byte(`{"input":[]}`), + expectErr: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := openAIWSInputIsPrefixExtended(tt.previous, tt.current) + if tt.expectErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.want, got) + }) + } +} + +func TestNormalizeOpenAIWSJSONForCompare(t *testing.T) { + t.Parallel() + + normalized, err := normalizeOpenAIWSJSONForCompare([]byte(`{"b":2,"a":1}`)) + require.NoError(t, err) + require.Equal(t, `{"a":1,"b":2}`, string(normalized)) + + _, err = normalizeOpenAIWSJSONForCompare([]byte(" ")) + require.Error(t, err) + + _, err = normalizeOpenAIWSJSONForCompare([]byte(`{"a":`)) + require.Error(t, err) +} + +func TestNormalizeOpenAIWSJSONForCompareOrRaw(t *testing.T) { + t.Parallel() + + require.Equal(t, `{"a":1,"b":2}`, string(normalizeOpenAIWSJSONForCompareOrRaw([]byte(`{"b":2,"a":1}`)))) + require.Equal(t, `{"a":`, string(normalizeOpenAIWSJSONForCompareOrRaw([]byte(`{"a":`)))) +} + +func TestNormalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(t *testing.T) { + t.Parallel() + + normalized, err := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID( + []byte(`{"model":"gpt-5.1","input":[1],"previous_response_id":"resp_x","metadata":{"b":2,"a":1}}`), + ) + require.NoError(t, err) + require.False(t, gjson.GetBytes(normalized, "input").Exists()) + require.False(t, gjson.GetBytes(normalized, "previous_response_id").Exists()) + require.Equal(t, float64(1), gjson.GetBytes(normalized, "metadata.a").Float()) + + _, err = normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(nil) + require.Error(t, err) + + _, err = 
normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID([]byte(`[]`)) + require.Error(t, err) +} + +func TestOpenAIWSExtractNormalizedInputSequence(t *testing.T) { + t.Parallel() + + t.Run("empty_payload", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence(nil) + require.NoError(t, err) + require.False(t, exists) + require.Nil(t, items) + }) + + t.Run("input_missing", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"type":"response.create"}`)) + require.NoError(t, err) + require.False(t, exists) + require.Nil(t, items) + }) + + t.Run("input_array", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":[{"type":"input_text","text":"hello"}]}`)) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + }) + + t.Run("input_object", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":{"type":"input_text","text":"hello"}}`)) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + }) + + t.Run("input_string", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":"hello"}`)) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + require.Equal(t, `"hello"`, string(items[0])) + }) + + t.Run("input_number", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":42}`)) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + require.Equal(t, "42", string(items[0])) + }) + + t.Run("input_bool", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":true}`)) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + require.Equal(t, "true", string(items[0])) + }) + + t.Run("input_null", func(t *testing.T) { + items, exists, err := 
openAIWSExtractNormalizedInputSequence([]byte(`{"input":null}`)) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + require.Equal(t, "null", string(items[0])) + }) + + t.Run("input_invalid_array_json", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":[}`)) + require.Error(t, err) + require.True(t, exists) + require.Nil(t, items) + }) +} + +func TestShouldKeepIngressPreviousResponseID(t *testing.T) { + t.Parallel() + + previousPayload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "store":false, + "tools":[{"type":"function","name":"tool_a"}], + "input":[{"type":"input_text","text":"hello"}] + }`) + currentStrictPayload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "store":false, + "tools":[{"name":"tool_a","type":"function"}], + "previous_response_id":"resp_turn_1", + "input":[{"text":"hello","type":"input_text"},{"type":"input_text","text":"world"}] + }`) + + t.Run("strict_incremental_keep", func(t *testing.T) { + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_1", false) + require.NoError(t, err) + require.True(t, keep) + require.Equal(t, "strict_incremental_ok", reason) + }) + + t.Run("missing_previous_response_id", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`) + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false) + require.NoError(t, err) + require.False(t, keep) + require.Equal(t, "missing_previous_response_id", reason) + }) + + t.Run("missing_last_turn_response_id", func(t *testing.T) { + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "", false) + require.NoError(t, err) + require.False(t, keep) + require.Equal(t, "missing_last_turn_response_id", reason) + }) + + t.Run("previous_response_id_mismatch", func(t *testing.T) { + keep, 
reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_other", false) + require.NoError(t, err) + require.False(t, keep) + require.Equal(t, "previous_response_id_mismatch", reason) + }) + + t.Run("missing_previous_turn_payload", func(t *testing.T) { + keep, reason, err := shouldKeepIngressPreviousResponseID(nil, currentStrictPayload, "resp_turn_1", false) + require.NoError(t, err) + require.False(t, keep) + require.Equal(t, "missing_previous_turn_payload", reason) + }) + + t.Run("non_input_changed", func(t *testing.T) { + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1-mini", + "store":false, + "tools":[{"type":"function","name":"tool_a"}], + "previous_response_id":"resp_turn_1", + "input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"}] + }`) + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false) + require.NoError(t, err) + require.False(t, keep) + require.Equal(t, "non_input_changed", reason) + }) + + t.Run("delta_input_keeps_previous_response_id", func(t *testing.T) { + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "store":false, + "tools":[{"type":"function","name":"tool_a"}], + "previous_response_id":"resp_turn_1", + "input":[{"type":"input_text","text":"different"}] + }`) + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false) + require.NoError(t, err) + require.True(t, keep) + require.Equal(t, "strict_incremental_ok", reason) + }) + + t.Run("function_call_output_keeps_previous_response_id", func(t *testing.T) { + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "store":false, + "previous_response_id":"resp_external", + "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}] + }`) + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", true) + 
require.NoError(t, err) + require.True(t, keep) + require.Equal(t, "has_function_call_output", reason) + }) + + t.Run("non_input_compare_error", func(t *testing.T) { + keep, reason, err := shouldKeepIngressPreviousResponseID([]byte(`[]`), currentStrictPayload, "resp_turn_1", false) + require.Error(t, err) + require.False(t, keep) + require.Equal(t, "non_input_compare_error", reason) + }) + + t.Run("current_payload_compare_error", func(t *testing.T) { + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, []byte(`{"previous_response_id":"resp_turn_1","input":[}`), "resp_turn_1", false) + require.Error(t, err) + require.False(t, keep) + require.Equal(t, "non_input_compare_error", reason) + }) +} + +func TestBuildOpenAIWSReplayInputSequence(t *testing.T) { + t.Parallel() + + lastFull := []json.RawMessage{ + json.RawMessage(`{"type":"input_text","text":"hello"}`), + } + + t.Run("no_previous_response_id_use_current", func(t *testing.T) { + items, exists, err := buildOpenAIWSReplayInputSequence( + lastFull, + true, + []byte(`{"input":[{"type":"input_text","text":"new"}]}`), + false, + ) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + require.Equal(t, "new", gjson.GetBytes(items[0], "text").String()) + }) + + t.Run("previous_response_id_delta_append", func(t *testing.T) { + items, exists, err := buildOpenAIWSReplayInputSequence( + lastFull, + true, + []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"world"}]}`), + true, + ) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 2) + require.Equal(t, "hello", gjson.GetBytes(items[0], "text").String()) + require.Equal(t, "world", gjson.GetBytes(items[1], "text").String()) + }) + + t.Run("previous_response_id_full_input_replace", func(t *testing.T) { + items, exists, err := buildOpenAIWSReplayInputSequence( + lastFull, + true, + 
[]byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"}]}`), + true, + ) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 2) + require.Equal(t, "hello", gjson.GetBytes(items[0], "text").String()) + require.Equal(t, "world", gjson.GetBytes(items[1], "text").String()) + }) +} + +func TestSetOpenAIWSPayloadInputSequence(t *testing.T) { + t.Parallel() + + t.Run("set_items", func(t *testing.T) { + original := []byte(`{"type":"response.create","previous_response_id":"resp_1"}`) + items := []json.RawMessage{ + json.RawMessage(`{"type":"input_text","text":"hello"}`), + json.RawMessage(`{"type":"input_text","text":"world"}`), + } + updated, err := setOpenAIWSPayloadInputSequence(original, items, true) + require.NoError(t, err) + require.Equal(t, "hello", gjson.GetBytes(updated, "input.0.text").String()) + require.Equal(t, "world", gjson.GetBytes(updated, "input.1.text").String()) + }) + + t.Run("preserve_empty_array_not_null", func(t *testing.T) { + original := []byte(`{"type":"response.create","previous_response_id":"resp_1"}`) + updated, err := setOpenAIWSPayloadInputSequence(original, nil, true) + require.NoError(t, err) + require.True(t, gjson.GetBytes(updated, "input").IsArray()) + require.Len(t, gjson.GetBytes(updated, "input").Array(), 0) + require.False(t, gjson.GetBytes(updated, "input").Type == gjson.Null) + }) +} + +func TestCloneOpenAIWSRawMessages(t *testing.T) { + t.Parallel() + + t.Run("nil_slice", func(t *testing.T) { + cloned := cloneOpenAIWSRawMessages(nil) + require.Nil(t, cloned) + }) + + t.Run("empty_slice", func(t *testing.T) { + items := make([]json.RawMessage, 0) + cloned := cloneOpenAIWSRawMessages(items) + require.NotNil(t, cloned) + require.Len(t, cloned, 0) + }) +} diff --git a/backend/internal/service/openai_ws_forwarder_retry_payload_test.go b/backend/internal/service/openai_ws_forwarder_retry_payload_test.go new file mode 100644 index 
00000000..0ea7e1c7 --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_retry_payload_test.go @@ -0,0 +1,50 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestApplyOpenAIWSRetryPayloadStrategy_KeepPromptCacheKey(t *testing.T) { + payload := map[string]any{ + "model": "gpt-5.3-codex", + "prompt_cache_key": "pcache_123", + "include": []any{"reasoning.encrypted_content"}, + "text": map[string]any{ + "verbosity": "low", + }, + "tools": []any{map[string]any{"type": "function"}}, + } + + strategy, removed := applyOpenAIWSRetryPayloadStrategy(payload, 3) + require.Equal(t, "trim_optional_fields", strategy) + require.Contains(t, removed, "include") + require.NotContains(t, removed, "prompt_cache_key") + require.Equal(t, "pcache_123", payload["prompt_cache_key"]) + require.NotContains(t, payload, "include") + require.Contains(t, payload, "text") +} + +func TestApplyOpenAIWSRetryPayloadStrategy_AttemptSixKeepsSemanticFields(t *testing.T) { + payload := map[string]any{ + "prompt_cache_key": "pcache_456", + "instructions": "long instructions", + "tools": []any{map[string]any{"type": "function"}}, + "parallel_tool_calls": true, + "tool_choice": "auto", + "include": []any{"reasoning.encrypted_content"}, + "text": map[string]any{"verbosity": "high"}, + } + + strategy, removed := applyOpenAIWSRetryPayloadStrategy(payload, 6) + require.Equal(t, "trim_optional_fields", strategy) + require.Contains(t, removed, "include") + require.NotContains(t, removed, "prompt_cache_key") + require.Equal(t, "pcache_456", payload["prompt_cache_key"]) + require.Contains(t, payload, "instructions") + require.Contains(t, payload, "tools") + require.Contains(t, payload, "tool_choice") + require.Contains(t, payload, "parallel_tool_calls") + require.Contains(t, payload, "text") +} diff --git a/backend/internal/service/openai_ws_forwarder_success_test.go b/backend/internal/service/openai_ws_forwarder_success_test.go new file mode 100644 index 
00000000..592801f6 --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_success_test.go @@ -0,0 +1,1306 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestOpenAIGatewayService_Forward_WSv2_SuccessAndBindSticky(t *testing.T) { + gin.SetMode(gin.TestMode) + + type receivedPayload struct { + Type string + PreviousResponseID string + StreamExists bool + Stream bool + } + receivedCh := make(chan receivedPayload, 1) + + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var request map[string]any + if err := conn.ReadJSON(&request); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + requestJSON := requestToJSONString(request) + receivedCh <- receivedPayload{ + Type: strings.TrimSpace(gjson.Get(requestJSON, "type").String()), + PreviousResponseID: strings.TrimSpace(gjson.Get(requestJSON, "previous_response_id").String()), + StreamExists: gjson.Get(requestJSON, "stream").Exists(), + Stream: gjson.Get(requestJSON, "stream").Bool(), + } + + if err := conn.WriteJSON(map[string]any{ + "type": "response.created", + "response": map[string]any{ + "id": "resp_new_1", + "model": "gpt-5.1", + }, + }); err != nil { + t.Errorf("write response.created failed: %v", err) + return + } + if err := conn.WriteJSON(map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": "resp_new_1", + "model": "gpt-5.1", + 
"usage": map[string]any{ + "input_tokens": 12, + "output_tokens": 7, + "input_tokens_details": map[string]any{ + "cached_tokens": 3, + }, + }, + }, + }); err != nil { + t.Errorf("write response.completed failed: %v", err) + return + } + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "unit-test-agent/1.0") + groupID := int64(1001) + c.Set("api_key", &APIKey{GroupID: &groupID}) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 30 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 10 + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cache := &stubGatewayCache{} + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + cache: cache, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 9, + Name: "openai-ws", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 2, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := 
[]byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_1","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 12, result.Usage.InputTokens) + require.Equal(t, 7, result.Usage.OutputTokens) + require.Equal(t, 3, result.Usage.CacheReadInputTokens) + require.Equal(t, "resp_new_1", result.RequestID) + require.True(t, result.OpenAIWSMode) + require.False(t, gjson.GetBytes(upstream.lastBody, "model").Exists(), "WSv2 成功时不应回落 HTTP 上游") + + received := <-receivedCh + require.Equal(t, "response.create", received.Type) + require.Equal(t, "resp_prev_1", received.PreviousResponseID) + require.True(t, received.StreamExists, "WS 请求应携带 stream 字段") + require.False(t, received.Stream, "应保持客户端 stream=false 的原始语义") + + store := svc.getOpenAIWSStateStore() + mappedAccountID, getErr := store.GetResponseAccount(context.Background(), groupID, "resp_new_1") + require.NoError(t, getErr) + require.Equal(t, account.ID, mappedAccountID) + connID, ok := store.GetResponseConn("resp_new_1") + require.True(t, ok) + require.NotEmpty(t, connID) + + responseBody := rec.Body.Bytes() + require.Equal(t, "resp_new_1", gjson.GetBytes(responseBody, "id").String()) +} + +func requestToJSONString(payload map[string]any) string { + if len(payload) == 0 { + return "{}" + } + b, err := json.Marshal(payload) + if err != nil { + return "{}" + } + return string(b) +} + +func TestLogOpenAIWSBindResponseAccountWarn(t *testing.T) { + require.NotPanics(t, func() { + logOpenAIWSBindResponseAccountWarn(1, 2, "resp_ok", nil) + }) + require.NotPanics(t, func() { + logOpenAIWSBindResponseAccountWarn(1, 2, "resp_err", errors.New("bind failed")) + }) +} + +func TestOpenAIGatewayService_Forward_WSv2_RewriteModelAndToolCallsOnCompletedEvent(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = 
httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0") + groupID := int64(3001) + c.Set("api_key", &APIKey{GroupID: &groupID}) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_model_tool_1","model":"gpt-5.1","tool_calls":[{"function":{"name":"apply_patch","arguments":"{\"file_path\":\"/tmp/a.txt\",\"old_string\":\"a\",\"new_string\":\"b\"}"}}],"usage":{"input_tokens":2,"output_tokens":1}},"tool_calls":[{"function":{"name":"apply_patch","arguments":"{\"file_path\":\"/tmp/a.txt\",\"old_string\":\"a\",\"new_string\":\"b\"}"}}]}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 1301, + Name: "openai-rewrite", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "model_mapping": map[string]any{ + "custom-original-model": "gpt-5.1", 
+ }, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"custom-original-model","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_model_tool_1", result.RequestID) + require.Equal(t, "custom-original-model", gjson.GetBytes(rec.Body.Bytes(), "model").String(), "响应模型应回写为原始请求模型") + require.Equal(t, "edit", gjson.GetBytes(rec.Body.Bytes(), "tool_calls.0.function.name").String(), "工具名称应被修正为 OpenCode 规范") +} + +func TestOpenAIWSPayloadString_OnlyAcceptsStringValues(t *testing.T) { + payload := map[string]any{ + "type": nil, + "model": 123, + "prompt_cache_key": " cache-key ", + "previous_response_id": []byte(" resp_1 "), + } + + require.Equal(t, "", openAIWSPayloadString(payload, "type")) + require.Equal(t, "", openAIWSPayloadString(payload, "model")) + require.Equal(t, "cache-key", openAIWSPayloadString(payload, "prompt_cache_key")) + require.Equal(t, "resp_1", openAIWSPayloadString(payload, "previous_response_id")) +} + +func TestOpenAIGatewayService_Forward_WSv2_PoolReuseNotOneToOne(t *testing.T) { + gin.SetMode(gin.TestMode) + + var upgradeCount atomic.Int64 + var sequence atomic.Int64 + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgradeCount.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + for { + var request map[string]any + if err := conn.ReadJSON(&request); err != nil { + return + } + idx := sequence.Add(1) + responseID := "resp_reuse_" + strconv.FormatInt(idx, 10) + if err := conn.WriteJSON(map[string]any{ + "type": "response.created", + "response": map[string]any{ + "id": 
responseID, + "model": "gpt-5.1", + }, + }); err != nil { + return + } + if err := conn.WriteJSON(map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": responseID, + "model": "gpt-5.1", + "usage": map[string]any{ + "input_tokens": 2, + "output_tokens": 1, + }, + }, + }); err != nil { + return + } + } + })) + defer wsServer.Close() + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 30 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 10 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + account := &Account{ + ID: 19, + Name: "openai-ws", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 2, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + for i := 0; i < 2; i++ { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0") + groupID := int64(2001) + c.Set("api_key", &APIKey{GroupID: &groupID}) + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_reuse","input":[{"type":"input_text","text":"hello"}]}`) + 
result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, strings.HasPrefix(result.RequestID, "resp_reuse_")) + } + + require.Equal(t, int64(1), upgradeCount.Load(), "多个客户端请求应复用账号连接池而不是 1:1 对等建链") + metrics := svc.SnapshotOpenAIWSPoolMetrics() + require.GreaterOrEqual(t, metrics.AcquireReuseTotal, int64(1)) + require.GreaterOrEqual(t, metrics.ConnPickTotal, int64(1)) +} + +func TestOpenAIGatewayService_Forward_WSv2_OAuthStoreFalseByDefault(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0") + c.Request.Header.Set("session_id", "sess-oauth-1") + c.Request.Header.Set("conversation_id", "conv-oauth-1") + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.AllowStoreRecovery = false + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_oauth_1","model":"gpt-5.1","usage":{"input_tokens":3,"output_tokens":2}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + account := &Account{ + ID: 
29, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token-1", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"store":true,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_oauth_1", result.RequestID) + + require.NotNil(t, captureConn.lastWrite) + requestJSON := requestToJSONString(captureConn.lastWrite) + require.True(t, gjson.Get(requestJSON, "store").Exists(), "OAuth WSv2 应显式写入 store 字段") + require.False(t, gjson.Get(requestJSON, "store").Bool(), "默认策略应将 OAuth store 置为 false") + require.True(t, gjson.Get(requestJSON, "stream").Exists(), "WSv2 payload 应保留 stream 字段") + require.True(t, gjson.Get(requestJSON, "stream").Bool(), "OAuth Codex 规范化后应强制 stream=true") + require.Equal(t, openAIWSBetaV2Value, captureDialer.lastHeaders.Get("OpenAI-Beta")) + require.Equal(t, "sess-oauth-1", captureDialer.lastHeaders.Get("session_id")) + require.Equal(t, "conv-oauth-1", captureDialer.lastHeaders.Get("conversation_id")) +} + +func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheKey(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0") + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + 
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_prompt_cache_key","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + account := &Account{ + ID: 31, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token-1", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":true,"prompt_cache_key":"pcache_123","input":[{"type":"input_text","text":"hi"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_prompt_cache_key", result.RequestID) + + require.Equal(t, "pcache_123", captureDialer.lastHeaders.Get("session_id")) + require.Empty(t, captureDialer.lastHeaders.Get("conversation_id")) + require.NotNil(t, captureConn.lastWrite) + require.True(t, gjson.Get(requestToJSONString(captureConn.lastWrite), "stream").Exists()) +} + +func TestOpenAIGatewayService_Forward_WSv1_Unsupported(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0") + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + 
cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsockets = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 39, + Name: "openai-ws-v1", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://api.openai.com/v1/responses", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_v1","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "ws v1") + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, rec.Body.String(), "WSv1") + require.Nil(t, upstream.lastReq, "WSv1 不支持时不应触发 HTTP 上游请求") +} + +func TestOpenAIGatewayService_Forward_WSv2_TurnStateAndMetadataReplayOnReconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + + var connIndex atomic.Int64 + headersCh := make(chan http.Header, 4) + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + idx := connIndex.Add(1) + 
headersCh <- cloneHeader(r.Header) + + respHeader := http.Header{} + if idx == 1 { + respHeader.Set("x-codex-turn-state", "turn_state_first") + } + conn, err := upgrader.Upgrade(w, r, respHeader) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var request map[string]any + if err := conn.ReadJSON(&request); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + responseID := "resp_turn_" + strconv.FormatInt(idx, 10) + if err := conn.WriteJSON(map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": responseID, + "model": "gpt-5.1", + "usage": map[string]any{ + "input_tokens": 2, + "output_tokens": 1, + }, + }, + }); err != nil { + t.Errorf("write response.completed failed: %v", err) + return + } + })) + defer wsServer.Close() + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 0 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 49, + Name: "openai-turn-state", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + reqBody := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + rec1 := 
httptest.NewRecorder() + c1, _ := gin.CreateTestContext(rec1) + c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c1.Request.Header.Set("session_id", "session_turn_state") + c1.Request.Header.Set("x-codex-turn-metadata", "turn_meta_1") + result1, err := svc.Forward(context.Background(), c1, account, reqBody) + require.NoError(t, err) + require.NotNil(t, result1) + + sessionHash := svc.GenerateSessionHash(c1, reqBody) + store := svc.getOpenAIWSStateStore() + turnState, ok := store.GetSessionTurnState(0, sessionHash) + require.True(t, ok) + require.Equal(t, "turn_state_first", turnState) + + // 主动淘汰连接,模拟下一次请求发生重连。 + connID, hasConn := store.GetResponseConn(result1.RequestID) + require.True(t, hasConn) + svc.getOpenAIWSConnPool().evictConn(account.ID, connID) + + rec2 := httptest.NewRecorder() + c2, _ := gin.CreateTestContext(rec2) + c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c2.Request.Header.Set("session_id", "session_turn_state") + c2.Request.Header.Set("x-codex-turn-metadata", "turn_meta_2") + result2, err := svc.Forward(context.Background(), c2, account, reqBody) + require.NoError(t, err) + require.NotNil(t, result2) + + firstHandshakeHeaders := <-headersCh + secondHandshakeHeaders := <-headersCh + require.Equal(t, "turn_meta_1", firstHandshakeHeaders.Get("X-Codex-Turn-Metadata")) + require.Equal(t, "turn_meta_2", secondHandshakeHeaders.Get("X-Codex-Turn-Metadata")) + require.Equal(t, "turn_state_first", secondHandshakeHeaders.Get("X-Codex-Turn-State")) +} + +func TestOpenAIGatewayService_Forward_WSv2_GeneratePrewarm(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("session_id", "session-prewarm") + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + 
cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_prewarm_1","model":"gpt-5.1","usage":{"input_tokens":0,"output_tokens":0}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_main_1","model":"gpt-5.1","usage":{"input_tokens":4,"output_tokens":2}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 59, + Name: "openai-prewarm", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_main_1", result.RequestID) + + require.Len(t, captureConn.writes, 2, "开启 generate=false 预热后应发送两次 WS 请求") + firstWrite := requestToJSONString(captureConn.writes[0]) + secondWrite := requestToJSONString(captureConn.writes[1]) + require.True(t, gjson.Get(firstWrite, "generate").Exists()) + require.False(t, gjson.Get(firstWrite, 
"generate").Bool()) + require.False(t, gjson.Get(secondWrite, "generate").Exists()) +} + +func TestOpenAIGatewayService_PrewarmReadHonorsParentContext(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled = true + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + svc := &OpenAIGatewayService{ + cfg: cfg, + toolCorrector: NewCodexToolCorrector(), + } + account := &Account{ + ID: 601, + Name: "openai-prewarm-timeout", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + } + conn := newOpenAIWSConn("prewarm_ctx_conn", account.ID, &openAIWSBlockingConn{ + readDelay: 200 * time.Millisecond, + }, nil) + lease := &openAIWSConnLease{ + accountID: account.ID, + conn: conn, + } + payload := map[string]any{ + "type": "response.create", + "model": "gpt-5.1", + } + + ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) + defer cancel() + start := time.Now() + err := svc.performOpenAIWSGeneratePrewarm( + ctx, + lease, + OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2}, + payload, + "", + map[string]any{"model": "gpt-5.1"}, + account, + nil, + 0, + ) + elapsed := time.Since(start) + require.Error(t, err) + require.Contains(t, err.Error(), "prewarm_read_event") + require.Less(t, elapsed, 180*time.Millisecond, "预热读取应受父 context 取消控制,不应阻塞到 read_timeout") +} + +func TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + 
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_meta_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_meta_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 69, + Name: "openai-turn-metadata", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + + rec1 := httptest.NewRecorder() + c1, _ := gin.CreateTestContext(rec1) + c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c1.Request.Header.Set("session_id", "session-metadata-reuse") + c1.Request.Header.Set("x-codex-turn-metadata", "turn_meta_payload_1") + result1, err := svc.Forward(context.Background(), c1, account, body) + require.NoError(t, err) + require.NotNil(t, result1) + require.Equal(t, "resp_meta_1", result1.RequestID) + + rec2 := httptest.NewRecorder() + c2, _ := gin.CreateTestContext(rec2) + c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c2.Request.Header.Set("session_id", "session-metadata-reuse") + c2.Request.Header.Set("x-codex-turn-metadata", "turn_meta_payload_2") + result2, err := 
svc.Forward(context.Background(), c2, account, body) + require.NoError(t, err) + require.NotNil(t, result2) + require.Equal(t, "resp_meta_2", result2.RequestID) + + require.Equal(t, 1, captureDialer.DialCount(), "同一账号两轮请求应复用同一 WS 连接") + require.Len(t, captureConn.writes, 2) + + firstWrite := requestToJSONString(captureConn.writes[0]) + secondWrite := requestToJSONString(captureConn.writes[1]) + require.Equal(t, "turn_meta_payload_1", gjson.Get(firstWrite, "client_metadata.x-codex-turn-metadata").String()) + require.Equal(t, "turn_meta_payload_2", gjson.Get(secondWrite, "client_metadata.x-codex-turn-metadata").String()) +} + +func TestOpenAIGatewayService_Forward_WSv2StoreFalseSessionConnIsolation(t *testing.T) { + gin.SetMode(gin.TestMode) + + var upgradeCount atomic.Int64 + var sequence atomic.Int64 + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgradeCount.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + for { + var request map[string]any + if err := conn.ReadJSON(&request); err != nil { + return + } + responseID := "resp_store_false_" + strconv.FormatInt(sequence.Add(1), 10) + if err := conn.WriteJSON(map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": responseID, + "model": "gpt-5.1", + "usage": map[string]any{ + "input_tokens": 1, + "output_tokens": 1, + }, + }, + }); err != nil { + return + } + } + })) + defer wsServer.Close() + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + 
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 4 + cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = true + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 79, + Name: "openai-store-false", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 2, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + + rec1 := httptest.NewRecorder() + c1, _ := gin.CreateTestContext(rec1) + c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c1.Request.Header.Set("session_id", "session_store_false_a") + result1, err := svc.Forward(context.Background(), c1, account, body) + require.NoError(t, err) + require.NotNil(t, result1) + require.Equal(t, int64(1), upgradeCount.Load()) + + rec2 := httptest.NewRecorder() + c2, _ := gin.CreateTestContext(rec2) + c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c2.Request.Header.Set("session_id", "session_store_false_a") + result2, err := svc.Forward(context.Background(), c2, account, body) + require.NoError(t, err) + require.NotNil(t, result2) + require.Equal(t, int64(1), upgradeCount.Load(), "同一 session(store=false) 应复用同一 WS 连接") + + rec3 := httptest.NewRecorder() + c3, _ := gin.CreateTestContext(rec3) + c3.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c3.Request.Header.Set("session_id", "session_store_false_b") + result3, err := svc.Forward(context.Background(), c3, account, body) + 
require.NoError(t, err) + require.NotNil(t, result3) + require.Equal(t, int64(2), upgradeCount.Load(), "不同 session(store=false) 应隔离连接,避免续链状态互相覆盖") +} + +func TestOpenAIGatewayService_Forward_WSv2StoreFalseDisableForceNewConnAllowsReuse(t *testing.T) { + gin.SetMode(gin.TestMode) + + var upgradeCount atomic.Int64 + var sequence atomic.Int64 + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgradeCount.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + for { + var request map[string]any + if err := conn.ReadJSON(&request); err != nil { + return + } + responseID := "resp_store_false_reuse_" + strconv.FormatInt(sequence.Add(1), 10) + if err := conn.WriteJSON(map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": responseID, + "model": "gpt-5.1", + "usage": map[string]any{ + "input_tokens": 1, + "output_tokens": 1, + }, + }, + }); err != nil { + return + } + } + })) + defer wsServer.Close() + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = false + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 80, + Name: "openai-store-false-reuse", + Platform: 
PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 2, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + + rec1 := httptest.NewRecorder() + c1, _ := gin.CreateTestContext(rec1) + c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c1.Request.Header.Set("session_id", "session_store_false_reuse_a") + result1, err := svc.Forward(context.Background(), c1, account, body) + require.NoError(t, err) + require.NotNil(t, result1) + require.Equal(t, int64(1), upgradeCount.Load()) + + rec2 := httptest.NewRecorder() + c2, _ := gin.CreateTestContext(rec2) + c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c2.Request.Header.Set("session_id", "session_store_false_reuse_b") + result2, err := svc.Forward(context.Background(), c2, account, body) + require.NoError(t, err) + require.NotNil(t, result2) + require.Equal(t, int64(1), upgradeCount.Load(), "关闭强制新连后,不同 session(store=false) 可复用连接") +} + +func TestOpenAIGatewayService_Forward_WSv2ReadTimeoutAppliesPerRead(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0") + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + 
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 1 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + readDelays: []time.Duration{ + 700 * time.Millisecond, + 700 * time.Millisecond, + }, + events: [][]byte{ + []byte(`{"type":"response.created","response":{"id":"resp_timeout_ok","model":"gpt-5.1"}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_timeout_ok","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_fallback","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 81, + Name: "openai-read-timeout", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_timeout_ok", result.RequestID) + require.Nil(t, upstream.lastReq, "每次 Read 都应独立应用超时;总时长超过 read_timeout 不应误回退 HTTP") +} + +type openAIWSCaptureDialer struct { + mu sync.Mutex + conn 
*openAIWSCaptureConn + lastHeaders http.Header + handshake http.Header + dialCount int +} + +func (d *openAIWSCaptureDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + _ = ctx + _ = wsURL + _ = proxyURL + d.mu.Lock() + d.lastHeaders = cloneHeader(headers) + d.dialCount++ + respHeaders := cloneHeader(d.handshake) + d.mu.Unlock() + return d.conn, 0, respHeaders, nil +} + +func (d *openAIWSCaptureDialer) DialCount() int { + d.mu.Lock() + defer d.mu.Unlock() + return d.dialCount +} + +type openAIWSCaptureConn struct { + mu sync.Mutex + readDelays []time.Duration + events [][]byte + lastWrite map[string]any + writes []map[string]any + closed bool +} + +func (c *openAIWSCaptureConn) WriteJSON(ctx context.Context, value any) error { + _ = ctx + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return errOpenAIWSConnClosed + } + switch payload := value.(type) { + case map[string]any: + c.lastWrite = cloneMapStringAny(payload) + c.writes = append(c.writes, cloneMapStringAny(payload)) + case json.RawMessage: + var parsed map[string]any + if err := json.Unmarshal(payload, &parsed); err == nil { + c.lastWrite = cloneMapStringAny(parsed) + c.writes = append(c.writes, cloneMapStringAny(parsed)) + } + case []byte: + var parsed map[string]any + if err := json.Unmarshal(payload, &parsed); err == nil { + c.lastWrite = cloneMapStringAny(parsed) + c.writes = append(c.writes, cloneMapStringAny(parsed)) + } + } + return nil +} + +func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) { + if ctx == nil { + ctx = context.Background() + } + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil, errOpenAIWSConnClosed + } + if len(c.events) == 0 { + c.mu.Unlock() + return nil, io.EOF + } + delay := time.Duration(0) + if len(c.readDelays) > 0 { + delay = c.readDelays[0] + c.readDelays = c.readDelays[1:] + } + event := c.events[0] + c.events = c.events[1:] + 
c.mu.Unlock() + if delay > 0 { + timer := time.NewTimer(delay) + defer timer.Stop() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-timer.C: + } + } + return event, nil +} + +func (c *openAIWSCaptureConn) Ping(ctx context.Context) error { + _ = ctx + return nil +} + +func (c *openAIWSCaptureConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + c.closed = true + return nil +} + +func cloneMapStringAny(src map[string]any) map[string]any { + if src == nil { + return nil + } + dst := make(map[string]any, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} diff --git a/backend/internal/service/openai_ws_pool.go b/backend/internal/service/openai_ws_pool.go new file mode 100644 index 00000000..db6a96a7 --- /dev/null +++ b/backend/internal/service/openai_ws_pool.go @@ -0,0 +1,1706 @@ +package service + +import ( + "context" + "errors" + "fmt" + "math" + "net/http" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "golang.org/x/sync/errgroup" +) + +const ( + openAIWSConnMaxAge = 60 * time.Minute + openAIWSConnHealthCheckIdle = 90 * time.Second + openAIWSConnHealthCheckTO = 2 * time.Second + openAIWSConnPrewarmExtraDelay = 2 * time.Second + openAIWSAcquireCleanupInterval = 3 * time.Second + openAIWSBackgroundPingInterval = 30 * time.Second + openAIWSBackgroundSweepTicker = 30 * time.Second + + openAIWSPrewarmFailureWindow = 30 * time.Second + openAIWSPrewarmFailureSuppress = 2 +) + +var ( + errOpenAIWSConnClosed = errors.New("openai ws connection closed") + errOpenAIWSConnQueueFull = errors.New("openai ws connection queue full") + errOpenAIWSPreferredConnUnavailable = errors.New("openai ws preferred connection unavailable") +) + +type openAIWSDialError struct { + StatusCode int + ResponseHeaders http.Header + Err error +} + +func (e *openAIWSDialError) Error() string { + if e == nil { + return "" + } + if e.StatusCode > 0 { + return fmt.Sprintf("openai ws dial failed: 
status=%d err=%v", e.StatusCode, e.Err) + } + return fmt.Sprintf("openai ws dial failed: %v", e.Err) +} + +func (e *openAIWSDialError) Unwrap() error { + if e == nil { + return nil + } + return e.Err +} + +type openAIWSAcquireRequest struct { + Account *Account + WSURL string + Headers http.Header + ProxyURL string + PreferredConnID string + // ForceNewConn: 强制本次获取新连接(避免复用导致连接内续链状态互相污染)。 + ForceNewConn bool + // ForcePreferredConn: 强制本次只使用 PreferredConnID,禁止漂移到其它连接。 + ForcePreferredConn bool +} + +type openAIWSConnLease struct { + pool *openAIWSConnPool + accountID int64 + conn *openAIWSConn + queueWait time.Duration + connPick time.Duration + reused bool + released atomic.Bool +} + +func (l *openAIWSConnLease) activeConn() (*openAIWSConn, error) { + if l == nil || l.conn == nil { + return nil, errOpenAIWSConnClosed + } + if l.released.Load() { + return nil, errOpenAIWSConnClosed + } + return l.conn, nil +} + +func (l *openAIWSConnLease) ConnID() string { + if l == nil || l.conn == nil { + return "" + } + return l.conn.id +} + +func (l *openAIWSConnLease) QueueWaitDuration() time.Duration { + if l == nil { + return 0 + } + return l.queueWait +} + +func (l *openAIWSConnLease) ConnPickDuration() time.Duration { + if l == nil { + return 0 + } + return l.connPick +} + +func (l *openAIWSConnLease) Reused() bool { + if l == nil { + return false + } + return l.reused +} + +func (l *openAIWSConnLease) HandshakeHeader(name string) string { + if l == nil || l.conn == nil { + return "" + } + return l.conn.handshakeHeader(name) +} + +func (l *openAIWSConnLease) IsPrewarmed() bool { + if l == nil || l.conn == nil { + return false + } + return l.conn.isPrewarmed() +} + +func (l *openAIWSConnLease) MarkPrewarmed() { + if l == nil || l.conn == nil { + return + } + l.conn.markPrewarmed() +} + +func (l *openAIWSConnLease) WriteJSON(value any, timeout time.Duration) error { + conn, err := l.activeConn() + if err != nil { + return err + } + return 
conn.writeJSONWithTimeout(context.Background(), value, timeout) +} + +func (l *openAIWSConnLease) WriteJSONWithContextTimeout(ctx context.Context, value any, timeout time.Duration) error { + conn, err := l.activeConn() + if err != nil { + return err + } + return conn.writeJSONWithTimeout(ctx, value, timeout) +} + +func (l *openAIWSConnLease) WriteJSONContext(ctx context.Context, value any) error { + conn, err := l.activeConn() + if err != nil { + return err + } + return conn.writeJSON(value, ctx) +} + +func (l *openAIWSConnLease) ReadMessage(timeout time.Duration) ([]byte, error) { + conn, err := l.activeConn() + if err != nil { + return nil, err + } + return conn.readMessageWithTimeout(timeout) +} + +func (l *openAIWSConnLease) ReadMessageContext(ctx context.Context) ([]byte, error) { + conn, err := l.activeConn() + if err != nil { + return nil, err + } + return conn.readMessage(ctx) +} + +func (l *openAIWSConnLease) ReadMessageWithContextTimeout(ctx context.Context, timeout time.Duration) ([]byte, error) { + conn, err := l.activeConn() + if err != nil { + return nil, err + } + return conn.readMessageWithContextTimeout(ctx, timeout) +} + +func (l *openAIWSConnLease) PingWithTimeout(timeout time.Duration) error { + conn, err := l.activeConn() + if err != nil { + return err + } + return conn.pingWithTimeout(timeout) +} + +func (l *openAIWSConnLease) MarkBroken() { + if l == nil || l.pool == nil || l.conn == nil || l.released.Load() { + return + } + l.pool.evictConn(l.accountID, l.conn.id) +} + +func (l *openAIWSConnLease) Release() { + if l == nil || l.conn == nil { + return + } + if !l.released.CompareAndSwap(false, true) { + return + } + l.conn.release() +} + +type openAIWSConn struct { + id string + ws openAIWSClientConn + + handshakeHeaders http.Header + + leaseCh chan struct{} + closedCh chan struct{} + closeOnce sync.Once + + readMu sync.Mutex + writeMu sync.Mutex + + waiters atomic.Int32 + createdAtNano atomic.Int64 + lastUsedNano atomic.Int64 + prewarmed 
atomic.Bool +} + +func newOpenAIWSConn(id string, _ int64, ws openAIWSClientConn, handshakeHeaders http.Header) *openAIWSConn { + now := time.Now() + conn := &openAIWSConn{ + id: id, + ws: ws, + handshakeHeaders: cloneHeader(handshakeHeaders), + leaseCh: make(chan struct{}, 1), + closedCh: make(chan struct{}), + } + conn.leaseCh <- struct{}{} + conn.createdAtNano.Store(now.UnixNano()) + conn.lastUsedNano.Store(now.UnixNano()) + return conn +} + +func (c *openAIWSConn) tryAcquire() bool { + if c == nil { + return false + } + select { + case <-c.closedCh: + return false + default: + } + select { + case <-c.leaseCh: + select { + case <-c.closedCh: + c.release() + return false + default: + } + return true + default: + return false + } +} + +func (c *openAIWSConn) acquire(ctx context.Context) error { + if c == nil { + return errOpenAIWSConnClosed + } + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-c.closedCh: + return errOpenAIWSConnClosed + case <-c.leaseCh: + select { + case <-c.closedCh: + c.release() + return errOpenAIWSConnClosed + default: + } + return nil + } + } +} + +func (c *openAIWSConn) release() { + if c == nil { + return + } + select { + case c.leaseCh <- struct{}{}: + default: + } + c.touch() +} + +func (c *openAIWSConn) close() { + if c == nil { + return + } + c.closeOnce.Do(func() { + close(c.closedCh) + if c.ws != nil { + _ = c.ws.Close() + } + select { + case c.leaseCh <- struct{}{}: + default: + } + }) +} + +func (c *openAIWSConn) writeJSONWithTimeout(parent context.Context, value any, timeout time.Duration) error { + if c == nil { + return errOpenAIWSConnClosed + } + select { + case <-c.closedCh: + return errOpenAIWSConnClosed + default: + } + + writeCtx := parent + if writeCtx == nil { + writeCtx = context.Background() + } + if timeout <= 0 { + return c.writeJSON(value, writeCtx) + } + var cancel context.CancelFunc + writeCtx, cancel = context.WithTimeout(writeCtx, timeout) + defer cancel() + return c.writeJSON(value, writeCtx) 
+} + +func (c *openAIWSConn) writeJSON(value any, writeCtx context.Context) error { + c.writeMu.Lock() + defer c.writeMu.Unlock() + if c.ws == nil { + return errOpenAIWSConnClosed + } + if writeCtx == nil { + writeCtx = context.Background() + } + if err := c.ws.WriteJSON(writeCtx, value); err != nil { + return err + } + c.touch() + return nil +} + +func (c *openAIWSConn) readMessageWithTimeout(timeout time.Duration) ([]byte, error) { + return c.readMessageWithContextTimeout(context.Background(), timeout) +} + +func (c *openAIWSConn) readMessageWithContextTimeout(parent context.Context, timeout time.Duration) ([]byte, error) { + if c == nil { + return nil, errOpenAIWSConnClosed + } + select { + case <-c.closedCh: + return nil, errOpenAIWSConnClosed + default: + } + + if parent == nil { + parent = context.Background() + } + if timeout <= 0 { + return c.readMessage(parent) + } + readCtx, cancel := context.WithTimeout(parent, timeout) + defer cancel() + return c.readMessage(readCtx) +} + +func (c *openAIWSConn) readMessage(readCtx context.Context) ([]byte, error) { + c.readMu.Lock() + defer c.readMu.Unlock() + if c.ws == nil { + return nil, errOpenAIWSConnClosed + } + if readCtx == nil { + readCtx = context.Background() + } + payload, err := c.ws.ReadMessage(readCtx) + if err != nil { + return nil, err + } + c.touch() + return payload, nil +} + +func (c *openAIWSConn) pingWithTimeout(timeout time.Duration) error { + if c == nil { + return errOpenAIWSConnClosed + } + select { + case <-c.closedCh: + return errOpenAIWSConnClosed + default: + } + + c.writeMu.Lock() + defer c.writeMu.Unlock() + if c.ws == nil { + return errOpenAIWSConnClosed + } + if timeout <= 0 { + timeout = openAIWSConnHealthCheckTO + } + pingCtx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + if err := c.ws.Ping(pingCtx); err != nil { + return err + } + return nil +} + +func (c *openAIWSConn) touch() { + if c == nil { + return + } + 
c.lastUsedNano.Store(time.Now().UnixNano()) +} + +func (c *openAIWSConn) createdAt() time.Time { + if c == nil { + return time.Time{} + } + nano := c.createdAtNano.Load() + if nano <= 0 { + return time.Time{} + } + return time.Unix(0, nano) +} + +func (c *openAIWSConn) lastUsedAt() time.Time { + if c == nil { + return time.Time{} + } + nano := c.lastUsedNano.Load() + if nano <= 0 { + return time.Time{} + } + return time.Unix(0, nano) +} + +func (c *openAIWSConn) idleDuration(now time.Time) time.Duration { + if c == nil { + return 0 + } + last := c.lastUsedAt() + if last.IsZero() { + return 0 + } + return now.Sub(last) +} + +func (c *openAIWSConn) age(now time.Time) time.Duration { + if c == nil { + return 0 + } + created := c.createdAt() + if created.IsZero() { + return 0 + } + return now.Sub(created) +} + +func (c *openAIWSConn) isLeased() bool { + if c == nil { + return false + } + return len(c.leaseCh) == 0 +} + +func (c *openAIWSConn) handshakeHeader(name string) string { + if c == nil || c.handshakeHeaders == nil { + return "" + } + return strings.TrimSpace(c.handshakeHeaders.Get(strings.TrimSpace(name))) +} + +func (c *openAIWSConn) isPrewarmed() bool { + if c == nil { + return false + } + return c.prewarmed.Load() +} + +func (c *openAIWSConn) markPrewarmed() { + if c == nil { + return + } + c.prewarmed.Store(true) +} + +type openAIWSAccountPool struct { + mu sync.Mutex + conns map[string]*openAIWSConn + pinnedConns map[string]int + creating int + lastCleanupAt time.Time + lastAcquire *openAIWSAcquireRequest + prewarmActive bool + prewarmUntil time.Time + prewarmFails int + prewarmFailAt time.Time +} + +type OpenAIWSPoolMetricsSnapshot struct { + AcquireTotal int64 + AcquireReuseTotal int64 + AcquireCreateTotal int64 + AcquireQueueWaitTotal int64 + AcquireQueueWaitMsTotal int64 + ConnPickTotal int64 + ConnPickMsTotal int64 + ScaleUpTotal int64 + ScaleDownTotal int64 +} + +type openAIWSPoolMetrics struct { + acquireTotal atomic.Int64 + acquireReuseTotal 
atomic.Int64 + acquireCreateTotal atomic.Int64 + acquireQueueWaitTotal atomic.Int64 + acquireQueueWaitMs atomic.Int64 + connPickTotal atomic.Int64 + connPickMs atomic.Int64 + scaleUpTotal atomic.Int64 + scaleDownTotal atomic.Int64 +} + +type openAIWSConnPool struct { + cfg *config.Config + // 通过接口解耦底层 WS 客户端实现,默认使用 coder/websocket。 + clientDialer openAIWSClientDialer + + accounts sync.Map // key: int64(accountID), value: *openAIWSAccountPool + seq atomic.Uint64 + + metrics openAIWSPoolMetrics + + workerStopCh chan struct{} + workerWg sync.WaitGroup + closeOnce sync.Once +} + +func newOpenAIWSConnPool(cfg *config.Config) *openAIWSConnPool { + pool := &openAIWSConnPool{ + cfg: cfg, + clientDialer: newDefaultOpenAIWSClientDialer(), + workerStopCh: make(chan struct{}), + } + pool.startBackgroundWorkers() + return pool +} + +func (p *openAIWSConnPool) SnapshotMetrics() OpenAIWSPoolMetricsSnapshot { + if p == nil { + return OpenAIWSPoolMetricsSnapshot{} + } + return OpenAIWSPoolMetricsSnapshot{ + AcquireTotal: p.metrics.acquireTotal.Load(), + AcquireReuseTotal: p.metrics.acquireReuseTotal.Load(), + AcquireCreateTotal: p.metrics.acquireCreateTotal.Load(), + AcquireQueueWaitTotal: p.metrics.acquireQueueWaitTotal.Load(), + AcquireQueueWaitMsTotal: p.metrics.acquireQueueWaitMs.Load(), + ConnPickTotal: p.metrics.connPickTotal.Load(), + ConnPickMsTotal: p.metrics.connPickMs.Load(), + ScaleUpTotal: p.metrics.scaleUpTotal.Load(), + ScaleDownTotal: p.metrics.scaleDownTotal.Load(), + } +} + +func (p *openAIWSConnPool) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot { + if p == nil { + return OpenAIWSTransportMetricsSnapshot{} + } + if dialer, ok := p.clientDialer.(openAIWSTransportMetricsDialer); ok { + return dialer.SnapshotTransportMetrics() + } + return OpenAIWSTransportMetricsSnapshot{} +} + +func (p *openAIWSConnPool) setClientDialerForTest(dialer openAIWSClientDialer) { + if p == nil || dialer == nil { + return + } + p.clientDialer = dialer +} + +// Close 停止后台 
worker 并关闭所有空闲连接,应在优雅关闭时调用。 +func (p *openAIWSConnPool) Close() { + if p == nil { + return + } + p.closeOnce.Do(func() { + if p.workerStopCh != nil { + close(p.workerStopCh) + } + p.workerWg.Wait() + // 遍历所有账户池,关闭全部空闲连接。 + p.accounts.Range(func(key, value any) bool { + ap, ok := value.(*openAIWSAccountPool) + if !ok || ap == nil { + return true + } + ap.mu.Lock() + for _, conn := range ap.conns { + if conn != nil && !conn.isLeased() { + conn.close() + } + } + ap.mu.Unlock() + return true + }) + }) +} + +func (p *openAIWSConnPool) startBackgroundWorkers() { + if p == nil || p.workerStopCh == nil { + return + } + p.workerWg.Add(2) + go func() { + defer p.workerWg.Done() + p.runBackgroundPingWorker() + }() + go func() { + defer p.workerWg.Done() + p.runBackgroundCleanupWorker() + }() +} + +type openAIWSIdlePingCandidate struct { + accountID int64 + conn *openAIWSConn +} + +func (p *openAIWSConnPool) runBackgroundPingWorker() { + if p == nil { + return + } + ticker := time.NewTicker(openAIWSBackgroundPingInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + p.runBackgroundPingSweep() + case <-p.workerStopCh: + return + } + } +} + +func (p *openAIWSConnPool) runBackgroundPingSweep() { + if p == nil { + return + } + candidates := p.snapshotIdleConnsForPing() + var g errgroup.Group + g.SetLimit(10) + for _, item := range candidates { + item := item + if item.conn == nil || item.conn.isLeased() || item.conn.waiters.Load() > 0 { + continue + } + g.Go(func() error { + if err := item.conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { + p.evictConn(item.accountID, item.conn.id) + } + return nil + }) + } + _ = g.Wait() +} + +func (p *openAIWSConnPool) snapshotIdleConnsForPing() []openAIWSIdlePingCandidate { + if p == nil { + return nil + } + candidates := make([]openAIWSIdlePingCandidate, 0) + p.accounts.Range(func(key, value any) bool { + accountID, ok := key.(int64) + if !ok || accountID <= 0 { + return true + } + ap, ok := 
value.(*openAIWSAccountPool) + if !ok || ap == nil { + return true + } + ap.mu.Lock() + for _, conn := range ap.conns { + if conn == nil || conn.isLeased() || conn.waiters.Load() > 0 { + continue + } + candidates = append(candidates, openAIWSIdlePingCandidate{ + accountID: accountID, + conn: conn, + }) + } + ap.mu.Unlock() + return true + }) + return candidates +} + +func (p *openAIWSConnPool) runBackgroundCleanupWorker() { + if p == nil { + return + } + ticker := time.NewTicker(openAIWSBackgroundSweepTicker) + defer ticker.Stop() + for { + select { + case <-ticker.C: + p.runBackgroundCleanupSweep(time.Now()) + case <-p.workerStopCh: + return + } + } +} + +func (p *openAIWSConnPool) runBackgroundCleanupSweep(now time.Time) { + if p == nil { + return + } + type cleanupResult struct { + evicted []*openAIWSConn + } + results := make([]cleanupResult, 0) + p.accounts.Range(func(_ any, value any) bool { + ap, ok := value.(*openAIWSAccountPool) + if !ok || ap == nil { + return true + } + maxConns := p.maxConnsHardCap() + ap.mu.Lock() + if ap.lastAcquire != nil && ap.lastAcquire.Account != nil { + maxConns = p.effectiveMaxConnsByAccount(ap.lastAcquire.Account) + } + evicted := p.cleanupAccountLocked(ap, now, maxConns) + ap.lastCleanupAt = now + ap.mu.Unlock() + if len(evicted) > 0 { + results = append(results, cleanupResult{evicted: evicted}) + } + return true + }) + for _, result := range results { + closeOpenAIWSConns(result.evicted) + } +} + +func (p *openAIWSConnPool) Acquire(ctx context.Context, req openAIWSAcquireRequest) (*openAIWSConnLease, error) { + if p != nil { + p.metrics.acquireTotal.Add(1) + } + return p.acquire(ctx, cloneOpenAIWSAcquireRequest(req), 0) +} + +func (p *openAIWSConnPool) acquire(ctx context.Context, req openAIWSAcquireRequest, retry int) (*openAIWSConnLease, error) { + if p == nil || req.Account == nil || req.Account.ID <= 0 { + return nil, errors.New("invalid ws acquire request") + } + if stringsTrim(req.WSURL) == "" { + return nil, 
errors.New("ws url is empty") + } + + accountID := req.Account.ID + effectiveMaxConns := p.effectiveMaxConnsByAccount(req.Account) + if effectiveMaxConns <= 0 { + return nil, errOpenAIWSConnQueueFull + } + var evicted []*openAIWSConn + ap := p.getOrCreateAccountPool(accountID) + ap.mu.Lock() + ap.lastAcquire = cloneOpenAIWSAcquireRequestPtr(&req) + now := time.Now() + if ap.lastCleanupAt.IsZero() || now.Sub(ap.lastCleanupAt) >= openAIWSAcquireCleanupInterval { + evicted = p.cleanupAccountLocked(ap, now, effectiveMaxConns) + ap.lastCleanupAt = now + } + pickStartedAt := time.Now() + allowReuse := !req.ForceNewConn + preferredConnID := stringsTrim(req.PreferredConnID) + forcePreferredConn := allowReuse && req.ForcePreferredConn + + if allowReuse { + if forcePreferredConn { + if preferredConnID == "" { + p.recordConnPickDuration(time.Since(pickStartedAt)) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + return nil, errOpenAIWSPreferredConnUnavailable + } + preferredConn, ok := ap.conns[preferredConnID] + if !ok || preferredConn == nil { + p.recordConnPickDuration(time.Since(pickStartedAt)) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + return nil, errOpenAIWSPreferredConnUnavailable + } + if preferredConn.tryAcquire() { + connPick := time.Since(pickStartedAt) + p.recordConnPickDuration(connPick) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + if p.shouldHealthCheckConn(preferredConn) { + if err := preferredConn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { + preferredConn.close() + p.evictConn(accountID, preferredConn.id) + if retry < 1 { + return p.acquire(ctx, req, retry+1) + } + return nil, err + } + } + lease := &openAIWSConnLease{ + pool: p, + accountID: accountID, + conn: preferredConn, + connPick: connPick, + reused: true, + } + p.metrics.acquireReuseTotal.Add(1) + p.ensureTargetIdleAsync(accountID) + return lease, nil + } + + connPick := time.Since(pickStartedAt) + p.recordConnPickDuration(connPick) + if int(preferredConn.waiters.Load()) >= 
p.queueLimitPerConn() { + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + return nil, errOpenAIWSConnQueueFull + } + preferredConn.waiters.Add(1) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + defer preferredConn.waiters.Add(-1) + waitStart := time.Now() + p.metrics.acquireQueueWaitTotal.Add(1) + + if err := preferredConn.acquire(ctx); err != nil { + if errors.Is(err, errOpenAIWSConnClosed) && retry < 1 { + return p.acquire(ctx, req, retry+1) + } + return nil, err + } + if p.shouldHealthCheckConn(preferredConn) { + if err := preferredConn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { + preferredConn.release() + preferredConn.close() + p.evictConn(accountID, preferredConn.id) + if retry < 1 { + return p.acquire(ctx, req, retry+1) + } + return nil, err + } + } + + queueWait := time.Since(waitStart) + p.metrics.acquireQueueWaitMs.Add(queueWait.Milliseconds()) + lease := &openAIWSConnLease{ + pool: p, + accountID: accountID, + conn: preferredConn, + queueWait: queueWait, + connPick: connPick, + reused: true, + } + p.metrics.acquireReuseTotal.Add(1) + p.ensureTargetIdleAsync(accountID) + return lease, nil + } + + if preferredConnID != "" { + if conn, ok := ap.conns[preferredConnID]; ok && conn.tryAcquire() { + connPick := time.Since(pickStartedAt) + p.recordConnPickDuration(connPick) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + if p.shouldHealthCheckConn(conn) { + if err := conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { + conn.close() + p.evictConn(accountID, conn.id) + if retry < 1 { + return p.acquire(ctx, req, retry+1) + } + return nil, err + } + } + lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick, reused: true} + p.metrics.acquireReuseTotal.Add(1) + p.ensureTargetIdleAsync(accountID) + return lease, nil + } + } + + best := p.pickLeastBusyConnLocked(ap, "") + if best != nil && best.tryAcquire() { + connPick := time.Since(pickStartedAt) + p.recordConnPickDuration(connPick) + ap.mu.Unlock() + 
closeOpenAIWSConns(evicted) + if p.shouldHealthCheckConn(best) { + if err := best.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { + best.close() + p.evictConn(accountID, best.id) + if retry < 1 { + return p.acquire(ctx, req, retry+1) + } + return nil, err + } + } + lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: best, connPick: connPick, reused: true} + p.metrics.acquireReuseTotal.Add(1) + p.ensureTargetIdleAsync(accountID) + return lease, nil + } + for _, conn := range ap.conns { + if conn == nil || conn == best { + continue + } + if conn.tryAcquire() { + connPick := time.Since(pickStartedAt) + p.recordConnPickDuration(connPick) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + if p.shouldHealthCheckConn(conn) { + if err := conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { + conn.close() + p.evictConn(accountID, conn.id) + if retry < 1 { + return p.acquire(ctx, req, retry+1) + } + return nil, err + } + } + lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick, reused: true} + p.metrics.acquireReuseTotal.Add(1) + p.ensureTargetIdleAsync(accountID) + return lease, nil + } + } + } + + if req.ForceNewConn && len(ap.conns)+ap.creating >= effectiveMaxConns { + if idle := p.pickOldestIdleConnLocked(ap); idle != nil { + delete(ap.conns, idle.id) + evicted = append(evicted, idle) + p.metrics.scaleDownTotal.Add(1) + } + } + + if len(ap.conns)+ap.creating < effectiveMaxConns { + connPick := time.Since(pickStartedAt) + p.recordConnPickDuration(connPick) + ap.creating++ + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + + conn, dialErr := p.dialConn(ctx, req) + + ap = p.getOrCreateAccountPool(accountID) + ap.mu.Lock() + ap.creating-- + if dialErr != nil { + ap.prewarmFails++ + ap.prewarmFailAt = time.Now() + ap.mu.Unlock() + return nil, dialErr + } + ap.conns[conn.id] = conn + ap.prewarmFails = 0 + ap.prewarmFailAt = time.Time{} + ap.mu.Unlock() + p.metrics.acquireCreateTotal.Add(1) + + if 
!conn.tryAcquire() { + if err := conn.acquire(ctx); err != nil { + conn.close() + p.evictConn(accountID, conn.id) + return nil, err + } + } + lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick} + p.ensureTargetIdleAsync(accountID) + return lease, nil + } + + if req.ForceNewConn { + p.recordConnPickDuration(time.Since(pickStartedAt)) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + return nil, errOpenAIWSConnQueueFull + } + + target := p.pickLeastBusyConnLocked(ap, req.PreferredConnID) + connPick := time.Since(pickStartedAt) + p.recordConnPickDuration(connPick) + if target == nil { + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + return nil, errOpenAIWSConnClosed + } + if int(target.waiters.Load()) >= p.queueLimitPerConn() { + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + return nil, errOpenAIWSConnQueueFull + } + target.waiters.Add(1) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + defer target.waiters.Add(-1) + waitStart := time.Now() + p.metrics.acquireQueueWaitTotal.Add(1) + + if err := target.acquire(ctx); err != nil { + if errors.Is(err, errOpenAIWSConnClosed) && retry < 1 { + return p.acquire(ctx, req, retry+1) + } + return nil, err + } + if p.shouldHealthCheckConn(target) { + if err := target.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { + target.release() + target.close() + p.evictConn(accountID, target.id) + if retry < 1 { + return p.acquire(ctx, req, retry+1) + } + return nil, err + } + } + + queueWait := time.Since(waitStart) + p.metrics.acquireQueueWaitMs.Add(queueWait.Milliseconds()) + lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: target, queueWait: queueWait, connPick: connPick, reused: true} + p.metrics.acquireReuseTotal.Add(1) + p.ensureTargetIdleAsync(accountID) + return lease, nil +} + +func (p *openAIWSConnPool) recordConnPickDuration(duration time.Duration) { + if p == nil { + return + } + if duration < 0 { + duration = 0 + } + p.metrics.connPickTotal.Add(1) + 
p.metrics.connPickMs.Add(duration.Milliseconds()) +} + +func (p *openAIWSConnPool) pickOldestIdleConnLocked(ap *openAIWSAccountPool) *openAIWSConn { + if ap == nil || len(ap.conns) == 0 { + return nil + } + var oldest *openAIWSConn + for _, conn := range ap.conns { + if conn == nil || conn.isLeased() || conn.waiters.Load() > 0 || p.isConnPinnedLocked(ap, conn.id) { + continue + } + if oldest == nil || conn.lastUsedAt().Before(oldest.lastUsedAt()) { + oldest = conn + } + } + return oldest +} + +func (p *openAIWSConnPool) getOrCreateAccountPool(accountID int64) *openAIWSAccountPool { + if p == nil || accountID <= 0 { + return nil + } + if existing, ok := p.accounts.Load(accountID); ok { + if ap, typed := existing.(*openAIWSAccountPool); typed && ap != nil { + return ap + } + } + ap := &openAIWSAccountPool{ + conns: make(map[string]*openAIWSConn), + pinnedConns: make(map[string]int), + } + actual, _ := p.accounts.LoadOrStore(accountID, ap) + if typed, ok := actual.(*openAIWSAccountPool); ok && typed != nil { + return typed + } + return ap +} + +// ensureAccountPoolLocked 兼容旧调用。 +func (p *openAIWSConnPool) ensureAccountPoolLocked(accountID int64) *openAIWSAccountPool { + return p.getOrCreateAccountPool(accountID) +} + +func (p *openAIWSConnPool) getAccountPool(accountID int64) (*openAIWSAccountPool, bool) { + if p == nil || accountID <= 0 { + return nil, false + } + value, ok := p.accounts.Load(accountID) + if !ok || value == nil { + return nil, false + } + ap, typed := value.(*openAIWSAccountPool) + return ap, typed && ap != nil +} + +func (p *openAIWSConnPool) isConnPinnedLocked(ap *openAIWSAccountPool, connID string) bool { + if ap == nil || connID == "" || len(ap.pinnedConns) == 0 { + return false + } + return ap.pinnedConns[connID] > 0 +} + +func (p *openAIWSConnPool) cleanupAccountLocked(ap *openAIWSAccountPool, now time.Time, maxConns int) []*openAIWSConn { + if ap == nil { + return nil + } + maxAge := p.maxConnAge() + + evicted := make([]*openAIWSConn, 0) + for 
id, conn := range ap.conns { + if conn == nil { + delete(ap.conns, id) + if len(ap.pinnedConns) > 0 { + delete(ap.pinnedConns, id) + } + continue + } + select { + case <-conn.closedCh: + delete(ap.conns, id) + if len(ap.pinnedConns) > 0 { + delete(ap.pinnedConns, id) + } + evicted = append(evicted, conn) + continue + default: + } + if p.isConnPinnedLocked(ap, id) { + continue + } + if maxAge > 0 && !conn.isLeased() && conn.age(now) > maxAge { + delete(ap.conns, id) + if len(ap.pinnedConns) > 0 { + delete(ap.pinnedConns, id) + } + evicted = append(evicted, conn) + } + } + + if maxConns <= 0 { + maxConns = p.maxConnsHardCap() + } + maxIdle := p.maxIdlePerAccount() + if maxIdle < 0 || maxIdle > maxConns { + maxIdle = maxConns + } + if maxIdle >= 0 && len(ap.conns) > maxIdle { + idleConns := make([]*openAIWSConn, 0, len(ap.conns)) + for id, conn := range ap.conns { + if conn == nil { + delete(ap.conns, id) + if len(ap.pinnedConns) > 0 { + delete(ap.pinnedConns, id) + } + continue + } + // 有等待者的连接不能在清理阶段被淘汰,否则等待中的 acquire 会收到 closed 错误。 + if conn.isLeased() || conn.waiters.Load() > 0 || p.isConnPinnedLocked(ap, conn.id) { + continue + } + idleConns = append(idleConns, conn) + } + sort.SliceStable(idleConns, func(i, j int) bool { + return idleConns[i].lastUsedAt().Before(idleConns[j].lastUsedAt()) + }) + redundant := len(ap.conns) - maxIdle + if redundant > len(idleConns) { + redundant = len(idleConns) + } + for i := 0; i < redundant; i++ { + conn := idleConns[i] + delete(ap.conns, conn.id) + if len(ap.pinnedConns) > 0 { + delete(ap.pinnedConns, conn.id) + } + evicted = append(evicted, conn) + } + if redundant > 0 { + p.metrics.scaleDownTotal.Add(int64(redundant)) + } + } + + return evicted +} + +func (p *openAIWSConnPool) pickLeastBusyConnLocked(ap *openAIWSAccountPool, preferredConnID string) *openAIWSConn { + if ap == nil || len(ap.conns) == 0 { + return nil + } + preferredConnID = stringsTrim(preferredConnID) + if preferredConnID != "" { + if conn, ok := 
ap.conns[preferredConnID]; ok { + return conn + } + } + var best *openAIWSConn + var bestWaiters int32 + var bestLastUsed time.Time + for _, conn := range ap.conns { + if conn == nil { + continue + } + waiters := conn.waiters.Load() + lastUsed := conn.lastUsedAt() + if best == nil || + waiters < bestWaiters || + (waiters == bestWaiters && lastUsed.Before(bestLastUsed)) { + best = conn + bestWaiters = waiters + bestLastUsed = lastUsed + } + } + return best +} + +func accountPoolLoadLocked(ap *openAIWSAccountPool) (inflight int, waiters int) { + if ap == nil { + return 0, 0 + } + for _, conn := range ap.conns { + if conn == nil { + continue + } + if conn.isLeased() { + inflight++ + } + waiters += int(conn.waiters.Load()) + } + return inflight, waiters +} + +// AccountPoolLoad 返回指定账号连接池的并发与排队快照。 +func (p *openAIWSConnPool) AccountPoolLoad(accountID int64) (inflight int, waiters int, conns int) { + if p == nil || accountID <= 0 { + return 0, 0, 0 + } + ap, ok := p.getAccountPool(accountID) + if !ok || ap == nil { + return 0, 0, 0 + } + ap.mu.Lock() + defer ap.mu.Unlock() + inflight, waiters = accountPoolLoadLocked(ap) + return inflight, waiters, len(ap.conns) +} + +func (p *openAIWSConnPool) ensureTargetIdleAsync(accountID int64) { + if p == nil || accountID <= 0 { + return + } + + var req openAIWSAcquireRequest + need := 0 + ap, ok := p.getAccountPool(accountID) + if !ok || ap == nil { + return + } + ap.mu.Lock() + defer ap.mu.Unlock() + if ap.lastAcquire == nil { + return + } + if ap.prewarmActive { + return + } + now := time.Now() + if !ap.prewarmUntil.IsZero() && now.Before(ap.prewarmUntil) { + return + } + if p.shouldSuppressPrewarmLocked(ap, now) { + return + } + effectiveMaxConns := p.maxConnsHardCap() + if ap.lastAcquire != nil && ap.lastAcquire.Account != nil { + effectiveMaxConns = p.effectiveMaxConnsByAccount(ap.lastAcquire.Account) + } + target := p.targetConnCountLocked(ap, effectiveMaxConns) + current := len(ap.conns) + ap.creating + if current >= target 
{ + return + } + need = target - current + if need <= 0 { + return + } + req = cloneOpenAIWSAcquireRequest(*ap.lastAcquire) + ap.prewarmActive = true + if cooldown := p.prewarmCooldown(); cooldown > 0 { + ap.prewarmUntil = now.Add(cooldown) + } + ap.creating += need + p.metrics.scaleUpTotal.Add(int64(need)) + + go p.prewarmConns(accountID, req, need) +} + +func (p *openAIWSConnPool) targetConnCountLocked(ap *openAIWSAccountPool, maxConns int) int { + if ap == nil { + return 0 + } + + if maxConns <= 0 { + return 0 + } + + minIdle := p.minIdlePerAccount() + if minIdle < 0 { + minIdle = 0 + } + if minIdle > maxConns { + minIdle = maxConns + } + + inflight, waiters := accountPoolLoadLocked(ap) + utilization := p.targetUtilization() + demand := inflight + waiters + if demand <= 0 { + return minIdle + } + + target := 1 + if demand > 1 { + target = int(math.Ceil(float64(demand) / utilization)) + } + if waiters > 0 && target < len(ap.conns)+1 { + target = len(ap.conns) + 1 + } + if target < minIdle { + target = minIdle + } + if target > maxConns { + target = maxConns + } + return target +} + +func (p *openAIWSConnPool) prewarmConns(accountID int64, req openAIWSAcquireRequest, total int) { + defer func() { + if ap, ok := p.getAccountPool(accountID); ok && ap != nil { + ap.mu.Lock() + ap.prewarmActive = false + ap.mu.Unlock() + } + }() + + for i := 0; i < total; i++ { + ctx, cancel := context.WithTimeout(context.Background(), p.dialTimeout()+openAIWSConnPrewarmExtraDelay) + conn, err := p.dialConn(ctx, req) + cancel() + + ap, ok := p.getAccountPool(accountID) + if !ok || ap == nil { + if conn != nil { + conn.close() + } + return + } + ap.mu.Lock() + if ap.creating > 0 { + ap.creating-- + } + if err != nil { + ap.prewarmFails++ + ap.prewarmFailAt = time.Now() + ap.mu.Unlock() + continue + } + if len(ap.conns) >= p.effectiveMaxConnsByAccount(req.Account) { + ap.mu.Unlock() + conn.close() + continue + } + ap.conns[conn.id] = conn + ap.prewarmFails = 0 + ap.prewarmFailAt = 
time.Time{} + ap.mu.Unlock() + } +} + +func (p *openAIWSConnPool) evictConn(accountID int64, connID string) { + if p == nil || accountID <= 0 || stringsTrim(connID) == "" { + return + } + var conn *openAIWSConn + ap, ok := p.getAccountPool(accountID) + if ok && ap != nil { + ap.mu.Lock() + if c, exists := ap.conns[connID]; exists { + conn = c + delete(ap.conns, connID) + if len(ap.pinnedConns) > 0 { + delete(ap.pinnedConns, connID) + } + } + ap.mu.Unlock() + } + if conn != nil { + conn.close() + } +} + +func (p *openAIWSConnPool) PinConn(accountID int64, connID string) bool { + if p == nil || accountID <= 0 { + return false + } + connID = stringsTrim(connID) + if connID == "" { + return false + } + ap, ok := p.getAccountPool(accountID) + if !ok || ap == nil { + return false + } + ap.mu.Lock() + defer ap.mu.Unlock() + if _, exists := ap.conns[connID]; !exists { + return false + } + if ap.pinnedConns == nil { + ap.pinnedConns = make(map[string]int) + } + ap.pinnedConns[connID]++ + return true +} + +func (p *openAIWSConnPool) UnpinConn(accountID int64, connID string) { + if p == nil || accountID <= 0 { + return + } + connID = stringsTrim(connID) + if connID == "" { + return + } + ap, ok := p.getAccountPool(accountID) + if !ok || ap == nil { + return + } + ap.mu.Lock() + defer ap.mu.Unlock() + if len(ap.pinnedConns) == 0 { + return + } + count := ap.pinnedConns[connID] + if count <= 1 { + delete(ap.pinnedConns, connID) + return + } + ap.pinnedConns[connID] = count - 1 +} + +func (p *openAIWSConnPool) dialConn(ctx context.Context, req openAIWSAcquireRequest) (*openAIWSConn, error) { + if p == nil || p.clientDialer == nil { + return nil, errors.New("openai ws client dialer is nil") + } + conn, status, handshakeHeaders, err := p.clientDialer.Dial(ctx, req.WSURL, req.Headers, req.ProxyURL) + if err != nil { + return nil, &openAIWSDialError{ + StatusCode: status, + ResponseHeaders: cloneHeader(handshakeHeaders), + Err: err, + } + } + if conn == nil { + return nil, 
&openAIWSDialError{ + StatusCode: status, + ResponseHeaders: cloneHeader(handshakeHeaders), + Err: errors.New("openai ws dialer returned nil connection"), + } + } + id := p.nextConnID(req.Account.ID) + return newOpenAIWSConn(id, req.Account.ID, conn, handshakeHeaders), nil +} + +func (p *openAIWSConnPool) nextConnID(accountID int64) string { + seq := p.seq.Add(1) + buf := make([]byte, 0, 32) + buf = append(buf, "oa_ws_"...) + buf = strconv.AppendInt(buf, accountID, 10) + buf = append(buf, '_') + buf = strconv.AppendUint(buf, seq, 10) + return string(buf) +} + +func (p *openAIWSConnPool) shouldHealthCheckConn(conn *openAIWSConn) bool { + if conn == nil { + return false + } + return conn.idleDuration(time.Now()) >= openAIWSConnHealthCheckIdle +} + +func (p *openAIWSConnPool) maxConnsHardCap() int { + if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount > 0 { + return p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount + } + return 8 +} + +func (p *openAIWSConnPool) dynamicMaxConnsEnabled() bool { + if p != nil && p.cfg != nil { + return p.cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled + } + return false +} + +func (p *openAIWSConnPool) modeRouterV2Enabled() bool { + if p != nil && p.cfg != nil { + return p.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled + } + return false +} + +func (p *openAIWSConnPool) maxConnsFactorByAccount(account *Account) float64 { + if p == nil || p.cfg == nil || account == nil { + return 1.0 + } + switch account.Type { + case AccountTypeOAuth: + if p.cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor > 0 { + return p.cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor + } + case AccountTypeAPIKey: + if p.cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor > 0 { + return p.cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor + } + } + return 1.0 +} + +func (p *openAIWSConnPool) effectiveMaxConnsByAccount(account *Account) int { + hardCap := p.maxConnsHardCap() + if hardCap <= 0 { + return 0 + } + if p.modeRouterV2Enabled() { + if account == nil { + return hardCap 
+ } + if account.Concurrency <= 0 { + return 0 + } + return account.Concurrency + } + if account == nil || !p.dynamicMaxConnsEnabled() { + return hardCap + } + if account.Concurrency <= 0 { + // 0/-1 等“无限制”并发场景下,仍由全局硬上限兜底。 + return hardCap + } + factor := p.maxConnsFactorByAccount(account) + if factor <= 0 { + factor = 1.0 + } + effective := int(math.Ceil(float64(account.Concurrency) * factor)) + if effective < 1 { + effective = 1 + } + if effective > hardCap { + effective = hardCap + } + return effective +} + +func (p *openAIWSConnPool) minIdlePerAccount() int { + if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MinIdlePerAccount >= 0 { + return p.cfg.Gateway.OpenAIWS.MinIdlePerAccount + } + return 0 +} + +func (p *openAIWSConnPool) maxIdlePerAccount() int { + if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MaxIdlePerAccount >= 0 { + return p.cfg.Gateway.OpenAIWS.MaxIdlePerAccount + } + return 4 +} + +func (p *openAIWSConnPool) maxConnAge() time.Duration { + return openAIWSConnMaxAge +} + +func (p *openAIWSConnPool) queueLimitPerConn() int { + if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.QueueLimitPerConn > 0 { + return p.cfg.Gateway.OpenAIWS.QueueLimitPerConn + } + return 256 +} + +func (p *openAIWSConnPool) targetUtilization() float64 { + if p != nil && p.cfg != nil { + ratio := p.cfg.Gateway.OpenAIWS.PoolTargetUtilization + if ratio > 0 && ratio <= 1 { + return ratio + } + } + return 0.7 +} + +func (p *openAIWSConnPool) prewarmCooldown() time.Duration { + if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.PrewarmCooldownMS > 0 { + return time.Duration(p.cfg.Gateway.OpenAIWS.PrewarmCooldownMS) * time.Millisecond + } + return 0 +} + +func (p *openAIWSConnPool) shouldSuppressPrewarmLocked(ap *openAIWSAccountPool, now time.Time) bool { + if ap == nil { + return true + } + if ap.prewarmFails <= 0 { + return false + } + if ap.prewarmFailAt.IsZero() { + ap.prewarmFails = 0 + return false + } + if now.Sub(ap.prewarmFailAt) > 
openAIWSPrewarmFailureWindow { + ap.prewarmFails = 0 + ap.prewarmFailAt = time.Time{} + return false + } + return ap.prewarmFails >= openAIWSPrewarmFailureSuppress +} + +func (p *openAIWSConnPool) dialTimeout() time.Duration { + if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.DialTimeoutSeconds > 0 { + return time.Duration(p.cfg.Gateway.OpenAIWS.DialTimeoutSeconds) * time.Second + } + return 10 * time.Second +} + +func cloneOpenAIWSAcquireRequest(req openAIWSAcquireRequest) openAIWSAcquireRequest { + copied := req + copied.Headers = cloneHeader(req.Headers) + copied.WSURL = stringsTrim(req.WSURL) + copied.ProxyURL = stringsTrim(req.ProxyURL) + copied.PreferredConnID = stringsTrim(req.PreferredConnID) + return copied +} + +func cloneOpenAIWSAcquireRequestPtr(req *openAIWSAcquireRequest) *openAIWSAcquireRequest { + if req == nil { + return nil + } + copied := cloneOpenAIWSAcquireRequest(*req) + return &copied +} + +func cloneHeader(src http.Header) http.Header { + if src == nil { + return nil + } + dst := make(http.Header, len(src)) + for k, vals := range src { + if len(vals) == 0 { + dst[k] = nil + continue + } + copied := make([]string, len(vals)) + copy(copied, vals) + dst[k] = copied + } + return dst +} + +func closeOpenAIWSConns(conns []*openAIWSConn) { + if len(conns) == 0 { + return + } + for _, conn := range conns { + if conn == nil { + continue + } + conn.close() + } +} + +func stringsTrim(value string) string { + return strings.TrimSpace(value) +} diff --git a/backend/internal/service/openai_ws_pool_benchmark_test.go b/backend/internal/service/openai_ws_pool_benchmark_test.go new file mode 100644 index 00000000..bff74b62 --- /dev/null +++ b/backend/internal/service/openai_ws_pool_benchmark_test.go @@ -0,0 +1,58 @@ +package service + +import ( + "context" + "errors" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +func BenchmarkOpenAIWSPoolAcquire(b *testing.B) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8 
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 4 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 256 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1 + + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(&openAIWSCountingDialer{}) + + account := &Account{ID: 1001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + req := openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + } + ctx := context.Background() + + lease, err := pool.Acquire(ctx, req) + if err != nil { + b.Fatalf("warm acquire failed: %v", err) + } + lease.Release() + + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + var ( + got *openAIWSConnLease + acquireErr error + ) + for retry := 0; retry < 3; retry++ { + got, acquireErr = pool.Acquire(ctx, req) + if acquireErr == nil { + break + } + if !errors.Is(acquireErr, errOpenAIWSConnClosed) { + break + } + } + if acquireErr != nil { + b.Fatalf("acquire failed: %v", acquireErr) + } + got.Release() + } + }) +} diff --git a/backend/internal/service/openai_ws_pool_test.go b/backend/internal/service/openai_ws_pool_test.go new file mode 100644 index 00000000..b2683ee0 --- /dev/null +++ b/backend/internal/service/openai_ws_pool_test.go @@ -0,0 +1,1709 @@ +package service + +import ( + "context" + "errors" + "net/http" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestOpenAIWSConnPool_CleanupStaleAndTrimIdle(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + pool := newOpenAIWSConnPool(cfg) + + accountID := int64(10) + ap := pool.getOrCreateAccountPool(accountID) + + stale := newOpenAIWSConn("stale", accountID, nil, nil) + stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano()) + stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano()) + + idleOld := 
newOpenAIWSConn("idle_old", accountID, nil, nil) + idleOld.lastUsedNano.Store(time.Now().Add(-10 * time.Minute).UnixNano()) + + idleNew := newOpenAIWSConn("idle_new", accountID, nil, nil) + idleNew.lastUsedNano.Store(time.Now().Add(-1 * time.Minute).UnixNano()) + + ap.conns[stale.id] = stale + ap.conns[idleOld.id] = idleOld + ap.conns[idleNew.id] = idleNew + + evicted := pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap()) + closeOpenAIWSConns(evicted) + + require.Nil(t, ap.conns["stale"], "stale connection should be rotated") + require.Nil(t, ap.conns["idle_old"], "old idle should be trimmed by max_idle") + require.NotNil(t, ap.conns["idle_new"], "newer idle should be kept") +} + +func TestOpenAIWSConnPool_NextConnIDFormat(t *testing.T) { + pool := newOpenAIWSConnPool(&config.Config{}) + id1 := pool.nextConnID(42) + id2 := pool.nextConnID(42) + + require.True(t, strings.HasPrefix(id1, "oa_ws_42_")) + require.True(t, strings.HasPrefix(id2, "oa_ws_42_")) + require.NotEqual(t, id1, id2) + require.Equal(t, "oa_ws_42_1", id1) + require.Equal(t, "oa_ws_42_2", id2) +} + +func TestOpenAIWSConnPool_AcquireCleanupInterval(t *testing.T) { + require.Equal(t, 3*time.Second, openAIWSAcquireCleanupInterval) + require.Less(t, openAIWSAcquireCleanupInterval, openAIWSBackgroundSweepTicker) +} + +func TestOpenAIWSConnLease_WriteJSONAndGuards(t *testing.T) { + conn := newOpenAIWSConn("lease_write", 1, &openAIWSFakeConn{}, nil) + lease := &openAIWSConnLease{conn: conn} + require.NoError(t, lease.WriteJSON(map[string]any{"type": "response.create"}, 0)) + + var nilLease *openAIWSConnLease + err := nilLease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "response.create"}, time.Second) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + + err = (&openAIWSConnLease{}).WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "response.create"}, time.Second) + require.ErrorIs(t, err, errOpenAIWSConnClosed) +} + +func 
TestOpenAIWSConn_WriteJSONWithTimeout_NilParentContextUsesBackground(t *testing.T) { + probe := &openAIWSContextProbeConn{} + conn := newOpenAIWSConn("ctx_probe", 1, probe, nil) + require.NoError(t, conn.writeJSONWithTimeout(context.Background(), map[string]any{"type": "response.create"}, 0)) + require.NotNil(t, probe.lastWriteCtx) +} + +func TestOpenAIWSConnPool_TargetConnCountAdaptive(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 6 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.5 + + pool := newOpenAIWSConnPool(cfg) + ap := pool.getOrCreateAccountPool(88) + + conn1 := newOpenAIWSConn("c1", 88, nil, nil) + conn2 := newOpenAIWSConn("c2", 88, nil, nil) + require.True(t, conn1.tryAcquire()) + require.True(t, conn2.tryAcquire()) + conn1.waiters.Store(1) + conn2.waiters.Store(1) + + ap.conns[conn1.id] = conn1 + ap.conns[conn2.id] = conn2 + + target := pool.targetConnCountLocked(ap, pool.maxConnsHardCap()) + require.Equal(t, 6, target, "应按 inflight+waiters 与 target_utilization 自适应扩容到上限") + + conn1.release() + conn2.release() + conn1.waiters.Store(0) + conn2.waiters.Store(0) + target = pool.targetConnCountLocked(ap, pool.maxConnsHardCap()) + require.Equal(t, 1, target, "低负载时应缩回到最小空闲连接") +} + +func TestOpenAIWSConnPool_TargetConnCountMinIdleZero(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8 + + pool := newOpenAIWSConnPool(cfg) + ap := pool.getOrCreateAccountPool(66) + + target := pool.targetConnCountLocked(ap, pool.maxConnsHardCap()) + require.Equal(t, 0, target, "min_idle=0 且无负载时应允许缩容到 0") +} + +func TestOpenAIWSConnPool_EnsureTargetIdleAsync(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 2 + cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8 + 
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1 + + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(&openAIWSFakeDialer{}) + + accountID := int64(77) + account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + ap := pool.getOrCreateAccountPool(accountID) + ap.mu.Lock() + ap.lastAcquire = &openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + } + ap.mu.Unlock() + + pool.ensureTargetIdleAsync(accountID) + + require.Eventually(t, func() bool { + ap, ok := pool.getAccountPool(accountID) + if !ok || ap == nil { + return false + } + ap.mu.Lock() + defer ap.mu.Unlock() + return len(ap.conns) >= 2 + }, 2*time.Second, 20*time.Millisecond) + + metrics := pool.SnapshotMetrics() + require.GreaterOrEqual(t, metrics.ScaleUpTotal, int64(2)) +} + +func TestOpenAIWSConnPool_EnsureTargetIdleAsyncCooldown(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 2 + cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1 + cfg.Gateway.OpenAIWS.PrewarmCooldownMS = 500 + + pool := newOpenAIWSConnPool(cfg) + dialer := &openAIWSCountingDialer{} + pool.setClientDialerForTest(dialer) + + accountID := int64(178) + account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + ap := pool.getOrCreateAccountPool(accountID) + ap.mu.Lock() + ap.lastAcquire = &openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + } + ap.mu.Unlock() + + pool.ensureTargetIdleAsync(accountID) + require.Eventually(t, func() bool { + ap, ok := pool.getAccountPool(accountID) + if !ok || ap == nil { + return false + } + ap.mu.Lock() + defer ap.mu.Unlock() + return len(ap.conns) >= 2 && !ap.prewarmActive + }, 2*time.Second, 20*time.Millisecond) + firstDialCount := dialer.DialCount() + require.GreaterOrEqual(t, firstDialCount, 2) + + // 人工制造缺口触发新一轮预热需求。 + ap, ok := 
pool.getAccountPool(accountID) + require.True(t, ok) + require.NotNil(t, ap) + ap.mu.Lock() + for id := range ap.conns { + delete(ap.conns, id) + break + } + ap.mu.Unlock() + + pool.ensureTargetIdleAsync(accountID) + time.Sleep(120 * time.Millisecond) + require.Equal(t, firstDialCount, dialer.DialCount(), "cooldown 窗口内不应再次触发预热") + + time.Sleep(450 * time.Millisecond) + pool.ensureTargetIdleAsync(accountID) + require.Eventually(t, func() bool { + return dialer.DialCount() > firstDialCount + }, 2*time.Second, 20*time.Millisecond) +} + +func TestOpenAIWSConnPool_EnsureTargetIdleAsyncFailureSuppress(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1 + cfg.Gateway.OpenAIWS.PrewarmCooldownMS = 0 + + pool := newOpenAIWSConnPool(cfg) + dialer := &openAIWSAlwaysFailDialer{} + pool.setClientDialerForTest(dialer) + + accountID := int64(279) + account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + ap := pool.getOrCreateAccountPool(accountID) + ap.mu.Lock() + ap.lastAcquire = &openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + } + ap.mu.Unlock() + + pool.ensureTargetIdleAsync(accountID) + require.Eventually(t, func() bool { + ap, ok := pool.getAccountPool(accountID) + if !ok || ap == nil { + return false + } + ap.mu.Lock() + defer ap.mu.Unlock() + return !ap.prewarmActive + }, 2*time.Second, 20*time.Millisecond) + + pool.ensureTargetIdleAsync(accountID) + require.Eventually(t, func() bool { + ap, ok := pool.getAccountPool(accountID) + if !ok || ap == nil { + return false + } + ap.mu.Lock() + defer ap.mu.Unlock() + return !ap.prewarmActive + }, 2*time.Second, 20*time.Millisecond) + require.Equal(t, 2, dialer.DialCount()) + + // 连续失败达到阈值后,新的预热触发应被抑制,不再继续拨号。 + pool.ensureTargetIdleAsync(accountID) + time.Sleep(120 * time.Millisecond) + 
require.Equal(t, 2, dialer.DialCount()) +} + +func TestOpenAIWSConnPool_AcquireQueueWaitMetrics(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 4 + + pool := newOpenAIWSConnPool(cfg) + accountID := int64(99) + account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + conn := newOpenAIWSConn("busy", accountID, &openAIWSFakeConn{}, nil) + require.True(t, conn.tryAcquire()) // 占用连接,触发后续排队 + + ap := pool.ensureAccountPoolLocked(accountID) + ap.mu.Lock() + ap.conns[conn.id] = conn + ap.lastAcquire = &openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + } + ap.mu.Unlock() + + go func() { + time.Sleep(60 * time.Millisecond) + conn.release() + }() + + lease, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + }) + require.NoError(t, err) + require.NotNil(t, lease) + require.True(t, lease.Reused()) + require.GreaterOrEqual(t, lease.QueueWaitDuration(), 50*time.Millisecond) + lease.Release() + + metrics := pool.SnapshotMetrics() + require.GreaterOrEqual(t, metrics.AcquireQueueWaitTotal, int64(1)) + require.Greater(t, metrics.AcquireQueueWaitMsTotal, int64(0)) + require.GreaterOrEqual(t, metrics.ConnPickTotal, int64(1)) +} + +func TestOpenAIWSConnPool_ForceNewConnSkipsReuse(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + + pool := newOpenAIWSConnPool(cfg) + dialer := &openAIWSCountingDialer{} + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 123, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + lease1, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + }) + require.NoError(t, 
err) + require.NotNil(t, lease1) + lease1.Release() + + lease2, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + ForceNewConn: true, + }) + require.NoError(t, err) + require.NotNil(t, lease2) + lease2.Release() + + require.Equal(t, 2, dialer.DialCount(), "ForceNewConn=true 时应跳过空闲连接复用并新建连接") +} + +func TestOpenAIWSConnPool_AcquireForcePreferredConnUnavailable(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + + pool := newOpenAIWSConnPool(cfg) + account := &Account{ID: 124, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + ap := pool.getOrCreateAccountPool(account.ID) + otherConn := newOpenAIWSConn("other_conn", account.ID, &openAIWSFakeConn{}, nil) + ap.mu.Lock() + ap.conns[otherConn.id] = otherConn + ap.mu.Unlock() + + _, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + ForcePreferredConn: true, + }) + require.ErrorIs(t, err, errOpenAIWSPreferredConnUnavailable) + + _, err = pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + PreferredConnID: "missing_conn", + ForcePreferredConn: true, + }) + require.ErrorIs(t, err, errOpenAIWSPreferredConnUnavailable) +} + +func TestOpenAIWSConnPool_AcquireForcePreferredConnQueuesOnPreferredOnly(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 4 + + pool := newOpenAIWSConnPool(cfg) + account := &Account{ID: 125, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + ap := pool.getOrCreateAccountPool(account.ID) + preferredConn := newOpenAIWSConn("preferred_conn", account.ID, &openAIWSFakeConn{}, nil) + 
otherConn := newOpenAIWSConn("other_conn_idle", account.ID, &openAIWSFakeConn{}, nil) + require.True(t, preferredConn.tryAcquire(), "先占用 preferred 连接,触发排队获取") + ap.mu.Lock() + ap.conns[preferredConn.id] = preferredConn + ap.conns[otherConn.id] = otherConn + ap.lastCleanupAt = time.Now() + ap.mu.Unlock() + + go func() { + time.Sleep(60 * time.Millisecond) + preferredConn.release() + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + lease, err := pool.Acquire(ctx, openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + PreferredConnID: preferredConn.id, + ForcePreferredConn: true, + }) + require.NoError(t, err) + require.NotNil(t, lease) + require.Equal(t, preferredConn.id, lease.ConnID(), "严格模式应只等待并复用 preferred 连接,不可漂移") + require.GreaterOrEqual(t, lease.QueueWaitDuration(), 40*time.Millisecond) + lease.Release() + require.True(t, otherConn.tryAcquire(), "other 连接不应被严格模式抢占") + otherConn.release() +} + +func TestOpenAIWSConnPool_AcquireForcePreferredConnDirectAndQueueFull(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 1 + + pool := newOpenAIWSConnPool(cfg) + account := &Account{ID: 127, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + ap := pool.getOrCreateAccountPool(account.ID) + preferredConn := newOpenAIWSConn("preferred_conn_direct", account.ID, &openAIWSFakeConn{}, nil) + otherConn := newOpenAIWSConn("other_conn_direct", account.ID, &openAIWSFakeConn{}, nil) + ap.mu.Lock() + ap.conns[preferredConn.id] = preferredConn + ap.conns[otherConn.id] = otherConn + ap.lastCleanupAt = time.Now() + ap.mu.Unlock() + + lease, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + PreferredConnID: preferredConn.id, + ForcePreferredConn: true, + }) 
+ require.NoError(t, err) + require.Equal(t, preferredConn.id, lease.ConnID(), "preferred 空闲时应直接命中") + lease.Release() + + require.True(t, preferredConn.tryAcquire()) + preferredConn.waiters.Store(1) + _, err = pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + PreferredConnID: preferredConn.id, + ForcePreferredConn: true, + }) + require.ErrorIs(t, err, errOpenAIWSConnQueueFull, "严格模式下队列满应直接失败,不得漂移") + preferredConn.waiters.Store(0) + preferredConn.release() +} + +func TestOpenAIWSConnPool_CleanupSkipsPinnedConn(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 0 + + pool := newOpenAIWSConnPool(cfg) + accountID := int64(126) + ap := pool.getOrCreateAccountPool(accountID) + pinnedConn := newOpenAIWSConn("pinned_conn", accountID, &openAIWSFakeConn{}, nil) + idleConn := newOpenAIWSConn("idle_conn", accountID, &openAIWSFakeConn{}, nil) + ap.mu.Lock() + ap.conns[pinnedConn.id] = pinnedConn + ap.conns[idleConn.id] = idleConn + ap.mu.Unlock() + + require.True(t, pool.PinConn(accountID, pinnedConn.id)) + evicted := pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap()) + closeOpenAIWSConns(evicted) + + ap.mu.Lock() + _, pinnedExists := ap.conns[pinnedConn.id] + _, idleExists := ap.conns[idleConn.id] + ap.mu.Unlock() + require.True(t, pinnedExists, "被 active ingress 绑定的连接不应被 cleanup 回收") + require.False(t, idleExists, "非绑定的空闲连接应被回收") + + pool.UnpinConn(accountID, pinnedConn.id) + evicted = pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap()) + closeOpenAIWSConns(evicted) + ap.mu.Lock() + _, pinnedExists = ap.conns[pinnedConn.id] + ap.mu.Unlock() + require.False(t, pinnedExists, "解绑后连接应可被正常回收") +} + +func TestOpenAIWSConnPool_PinUnpinConnBranches(t *testing.T) { + var nilPool *openAIWSConnPool + require.False(t, nilPool.PinConn(1, "x")) + nilPool.UnpinConn(1, "x") + + cfg := &config.Config{} + pool 
:= newOpenAIWSConnPool(cfg) + accountID := int64(128) + ap := &openAIWSAccountPool{ + conns: map[string]*openAIWSConn{}, + } + pool.accounts.Store(accountID, ap) + + require.False(t, pool.PinConn(0, "x")) + require.False(t, pool.PinConn(999, "x")) + require.False(t, pool.PinConn(accountID, "")) + require.False(t, pool.PinConn(accountID, "missing")) + + conn := newOpenAIWSConn("pin_refcount", accountID, &openAIWSFakeConn{}, nil) + ap.mu.Lock() + ap.conns[conn.id] = conn + ap.mu.Unlock() + require.True(t, pool.PinConn(accountID, conn.id)) + require.True(t, pool.PinConn(accountID, conn.id)) + + ap.mu.Lock() + require.Equal(t, 2, ap.pinnedConns[conn.id]) + ap.mu.Unlock() + + pool.UnpinConn(accountID, conn.id) + ap.mu.Lock() + require.Equal(t, 1, ap.pinnedConns[conn.id]) + ap.mu.Unlock() + + pool.UnpinConn(accountID, conn.id) + ap.mu.Lock() + _, exists := ap.pinnedConns[conn.id] + ap.mu.Unlock() + require.False(t, exists) + + pool.UnpinConn(accountID, conn.id) + pool.UnpinConn(accountID, "") + pool.UnpinConn(0, conn.id) + pool.UnpinConn(999, conn.id) +} + +func TestOpenAIWSConnPool_EffectiveMaxConnsByAccount(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8 + cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true + cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor = 1.0 + cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0.6 + + pool := newOpenAIWSConnPool(cfg) + + oauthHigh := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 10} + require.Equal(t, 8, pool.effectiveMaxConnsByAccount(oauthHigh), "应受全局硬上限约束") + + oauthLow := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 3} + require.Equal(t, 3, pool.effectiveMaxConnsByAccount(oauthLow)) + + apiKeyHigh := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 10} + require.Equal(t, 6, pool.effectiveMaxConnsByAccount(apiKeyHigh), "API Key 应按系数缩放") + + apiKeyLow := &Account{Platform: PlatformOpenAI, Type: 
AccountTypeAPIKey, Concurrency: 1} + require.Equal(t, 1, pool.effectiveMaxConnsByAccount(apiKeyLow), "最小值应保持为 1") + + unlimited := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 0} + require.Equal(t, 8, pool.effectiveMaxConnsByAccount(unlimited), "无限并发应回退到全局硬上限") + + require.Equal(t, 8, pool.effectiveMaxConnsByAccount(nil), "缺少账号上下文应回退到全局硬上限") +} + +func TestOpenAIWSConnPool_EffectiveMaxConnsDisabledFallbackHardCap(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8 + cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = false + cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor = 1.0 + cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 1.0 + + pool := newOpenAIWSConnPool(cfg) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 2} + require.Equal(t, 8, pool.effectiveMaxConnsByAccount(account), "关闭动态模式后应保持旧行为") +} + +func TestOpenAIWSConnPool_EffectiveMaxConnsByAccount_ModeRouterV2UsesAccountConcurrency(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8 + cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true + cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor = 0.3 + cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0.6 + + pool := newOpenAIWSConnPool(cfg) + + high := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 20} + require.Equal(t, 20, pool.effectiveMaxConnsByAccount(high), "v2 路径应直接使用账号并发数作为池上限") + + nonPositive := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 0} + require.Equal(t, 0, pool.effectiveMaxConnsByAccount(nonPositive), "并发数<=0 时应不可调度") +} + +func TestOpenAIWSConnPool_AcquireRejectsWhenEffectiveMaxConnsIsZero(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8 + pool := newOpenAIWSConnPool(cfg) + + account := &Account{ID: 901, Platform: 
PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 0} + _, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + }) + require.ErrorIs(t, err, errOpenAIWSConnQueueFull) +} + +func TestOpenAIWSConnLease_ReadMessageWithContextTimeout_PerRead(t *testing.T) { + conn := newOpenAIWSConn("timeout", 1, &openAIWSBlockingConn{readDelay: 80 * time.Millisecond}, nil) + lease := &openAIWSConnLease{conn: conn} + + _, err := lease.ReadMessageWithContextTimeout(context.Background(), 20*time.Millisecond) + require.Error(t, err) + require.ErrorIs(t, err, context.DeadlineExceeded) + + payload, err := lease.ReadMessageWithContextTimeout(context.Background(), 150*time.Millisecond) + require.NoError(t, err) + require.Contains(t, string(payload), "response.completed") + + parentCtx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = lease.ReadMessageWithContextTimeout(parentCtx, 150*time.Millisecond) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) +} + +func TestOpenAIWSConnLease_WriteJSONWithContextTimeout_RespectsParentContext(t *testing.T) { + conn := newOpenAIWSConn("write_timeout_ctx", 1, &openAIWSWriteBlockingConn{}, nil) + lease := &openAIWSConnLease{conn: conn} + + parentCtx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(20 * time.Millisecond) + cancel() + }() + + start := time.Now() + err := lease.WriteJSONWithContextTimeout(parentCtx, map[string]any{"type": "response.create"}, 2*time.Minute) + elapsed := time.Since(start) + + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + require.Less(t, elapsed, 200*time.Millisecond) +} + +func TestOpenAIWSConnLease_PingWithTimeout(t *testing.T) { + conn := newOpenAIWSConn("ping_ok", 1, &openAIWSFakeConn{}, nil) + lease := &openAIWSConnLease{conn: conn} + require.NoError(t, lease.PingWithTimeout(50*time.Millisecond)) + + var nilLease *openAIWSConnLease + err := 
nilLease.PingWithTimeout(50 * time.Millisecond) + require.ErrorIs(t, err, errOpenAIWSConnClosed) +} + +func TestOpenAIWSConn_ReadAndWriteCanProceedConcurrently(t *testing.T) { + conn := newOpenAIWSConn("full_duplex", 1, &openAIWSBlockingConn{readDelay: 120 * time.Millisecond}, nil) + + readDone := make(chan error, 1) + go func() { + _, err := conn.readMessageWithContextTimeout(context.Background(), 200*time.Millisecond) + readDone <- err + }() + + // 让读取先占用 readMu。 + time.Sleep(20 * time.Millisecond) + + start := time.Now() + err := conn.pingWithTimeout(50 * time.Millisecond) + elapsed := time.Since(start) + + require.NoError(t, err) + require.Less(t, elapsed, 80*time.Millisecond, "写路径不应被读锁长期阻塞") + require.NoError(t, <-readDone) +} + +func TestOpenAIWSConnPool_BackgroundPingSweep_EvictsDeadIdleConn(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + pool := newOpenAIWSConnPool(cfg) + + accountID := int64(301) + ap := pool.getOrCreateAccountPool(accountID) + conn := newOpenAIWSConn("dead_idle", accountID, &openAIWSPingFailConn{}, nil) + ap.mu.Lock() + ap.conns[conn.id] = conn + ap.mu.Unlock() + + pool.runBackgroundPingSweep() + + ap.mu.Lock() + _, exists := ap.conns[conn.id] + ap.mu.Unlock() + require.False(t, exists, "后台 ping 失败的空闲连接应被回收") +} + +func TestOpenAIWSConnPool_BackgroundCleanupSweep_WithoutAcquire(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + pool := newOpenAIWSConnPool(cfg) + + accountID := int64(302) + ap := pool.getOrCreateAccountPool(accountID) + stale := newOpenAIWSConn("stale_bg", accountID, &openAIWSFakeConn{}, nil) + stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano()) + stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano()) + ap.mu.Lock() + ap.conns[stale.id] = stale + ap.mu.Unlock() + + pool.runBackgroundCleanupSweep(time.Now()) + + ap.mu.Lock() + _, exists := ap.conns[stale.id] + 
ap.mu.Unlock() + require.False(t, exists, "后台清理应在无新 acquire 时也回收过期连接") +} + +func TestOpenAIWSConnPool_BackgroundWorkerGuardBranches(t *testing.T) { + var nilPool *openAIWSConnPool + require.NotPanics(t, func() { + nilPool.startBackgroundWorkers() + nilPool.runBackgroundPingWorker() + nilPool.runBackgroundPingSweep() + _ = nilPool.snapshotIdleConnsForPing() + nilPool.runBackgroundCleanupWorker() + nilPool.runBackgroundCleanupSweep(time.Now()) + }) + + poolNoStop := &openAIWSConnPool{} + require.NotPanics(t, func() { + poolNoStop.startBackgroundWorkers() + }) + + poolStopPing := &openAIWSConnPool{workerStopCh: make(chan struct{})} + pingDone := make(chan struct{}) + go func() { + poolStopPing.runBackgroundPingWorker() + close(pingDone) + }() + close(poolStopPing.workerStopCh) + select { + case <-pingDone: + case <-time.After(500 * time.Millisecond): + t.Fatal("runBackgroundPingWorker 未在 stop 信号后退出") + } + + poolStopCleanup := &openAIWSConnPool{workerStopCh: make(chan struct{})} + cleanupDone := make(chan struct{}) + go func() { + poolStopCleanup.runBackgroundCleanupWorker() + close(cleanupDone) + }() + close(poolStopCleanup.workerStopCh) + select { + case <-cleanupDone: + case <-time.After(500 * time.Millisecond): + t.Fatal("runBackgroundCleanupWorker 未在 stop 信号后退出") + } +} + +func TestOpenAIWSConnPool_SnapshotIdleConnsForPing_SkipsInvalidEntries(t *testing.T) { + pool := &openAIWSConnPool{} + pool.accounts.Store("invalid-key", &openAIWSAccountPool{}) + pool.accounts.Store(int64(123), "invalid-value") + + accountID := int64(123) + ap := &openAIWSAccountPool{ + conns: make(map[string]*openAIWSConn), + } + ap.conns["nil_conn"] = nil + + leased := newOpenAIWSConn("leased", accountID, &openAIWSFakeConn{}, nil) + require.True(t, leased.tryAcquire()) + ap.conns[leased.id] = leased + + waiting := newOpenAIWSConn("waiting", accountID, &openAIWSFakeConn{}, nil) + waiting.waiters.Store(1) + ap.conns[waiting.id] = waiting + + idle := newOpenAIWSConn("idle", accountID, 
&openAIWSFakeConn{}, nil) + ap.conns[idle.id] = idle + + pool.accounts.Store(accountID, ap) + candidates := pool.snapshotIdleConnsForPing() + require.Len(t, candidates, 1) + require.Equal(t, idle.id, candidates[0].conn.id) +} + +func TestOpenAIWSConnPool_RunBackgroundCleanupSweep_SkipsInvalidAndUsesAccountCap(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4 + cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true + + pool := &openAIWSConnPool{cfg: cfg} + pool.accounts.Store("bad-key", "bad-value") + + accountID := int64(2026) + ap := &openAIWSAccountPool{ + conns: make(map[string]*openAIWSConn), + } + ap.conns["nil_conn"] = nil + stale := newOpenAIWSConn("stale_bg_cleanup", accountID, &openAIWSFakeConn{}, nil) + stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano()) + stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano()) + ap.conns[stale.id] = stale + ap.lastAcquire = &openAIWSAcquireRequest{ + Account: &Account{ + ID: accountID, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + }, + } + pool.accounts.Store(accountID, ap) + + now := time.Now() + require.NotPanics(t, func() { + pool.runBackgroundCleanupSweep(now) + }) + + ap.mu.Lock() + _, nilConnExists := ap.conns["nil_conn"] + _, exists := ap.conns[stale.id] + lastCleanupAt := ap.lastCleanupAt + ap.mu.Unlock() + + require.False(t, nilConnExists, "后台清理应移除无效 nil 连接条目") + require.False(t, exists, "后台清理应清理过期连接") + require.Equal(t, now, lastCleanupAt) +} + +func TestOpenAIWSConnPool_QueueLimitPerConn_DefaultAndConfigured(t *testing.T) { + var nilPool *openAIWSConnPool + require.Equal(t, 256, nilPool.queueLimitPerConn()) + + pool := &openAIWSConnPool{cfg: &config.Config{}} + require.Equal(t, 256, pool.queueLimitPerConn()) + + pool.cfg.Gateway.OpenAIWS.QueueLimitPerConn = 9 + require.Equal(t, 9, pool.queueLimitPerConn()) +} + +func TestOpenAIWSConnPool_Close(t *testing.T) { + cfg := &config.Config{} + pool := 
newOpenAIWSConnPool(cfg) + + // Close 应该可以安全调用 + pool.Close() + + // workerStopCh 应已关闭 + select { + case <-pool.workerStopCh: + // 预期:channel 已关闭 + default: + t.Fatal("Close 后 workerStopCh 应已关闭") + } + + // 多次调用 Close 不应 panic + pool.Close() + + // nil pool 调用 Close 不应 panic + var nilPool *openAIWSConnPool + nilPool.Close() +} + +func TestOpenAIWSDialError_ErrorAndUnwrap(t *testing.T) { + baseErr := errors.New("boom") + dialErr := &openAIWSDialError{StatusCode: 502, Err: baseErr} + require.Contains(t, dialErr.Error(), "status=502") + require.ErrorIs(t, dialErr.Unwrap(), baseErr) + + noStatus := &openAIWSDialError{Err: baseErr} + require.Contains(t, noStatus.Error(), "boom") + + var nilDialErr *openAIWSDialError + require.Equal(t, "", nilDialErr.Error()) + require.NoError(t, nilDialErr.Unwrap()) +} + +func TestOpenAIWSConnLease_ReadWriteHelpersAndConnStats(t *testing.T) { + conn := newOpenAIWSConn("helper_conn", 1, &openAIWSFakeConn{}, http.Header{ + "X-Test": []string{" value "}, + }) + lease := &openAIWSConnLease{conn: conn} + + require.NoError(t, lease.WriteJSONContext(context.Background(), map[string]any{"type": "response.create"})) + payload, err := lease.ReadMessage(100 * time.Millisecond) + require.NoError(t, err) + require.Contains(t, string(payload), "response.completed") + + payload, err = lease.ReadMessageContext(context.Background()) + require.NoError(t, err) + require.Contains(t, string(payload), "response.completed") + + payload, err = conn.readMessageWithTimeout(100 * time.Millisecond) + require.NoError(t, err) + require.Contains(t, string(payload), "response.completed") + + require.Equal(t, "value", conn.handshakeHeader(" X-Test ")) + require.NotZero(t, conn.createdAt()) + require.NotZero(t, conn.lastUsedAt()) + require.GreaterOrEqual(t, conn.age(time.Now()), time.Duration(0)) + require.GreaterOrEqual(t, conn.idleDuration(time.Now()), time.Duration(0)) + require.False(t, conn.isLeased()) + + // 覆盖空上下文路径 + _, err = 
conn.readMessage(context.Background()) + require.NoError(t, err) + + // 覆盖 nil 保护分支 + var nilConn *openAIWSConn + require.ErrorIs(t, nilConn.writeJSONWithTimeout(context.Background(), map[string]any{}, time.Second), errOpenAIWSConnClosed) + _, err = nilConn.readMessageWithTimeout(10 * time.Millisecond) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + _, err = nilConn.readMessageWithContextTimeout(context.Background(), 10*time.Millisecond) + require.ErrorIs(t, err, errOpenAIWSConnClosed) +} + +func TestOpenAIWSConnPool_PickOldestIdleAndAccountPoolLoad(t *testing.T) { + pool := &openAIWSConnPool{} + accountID := int64(404) + ap := &openAIWSAccountPool{conns: map[string]*openAIWSConn{}} + + idleOld := newOpenAIWSConn("idle_old", accountID, &openAIWSFakeConn{}, nil) + idleOld.lastUsedNano.Store(time.Now().Add(-10 * time.Minute).UnixNano()) + idleNew := newOpenAIWSConn("idle_new", accountID, &openAIWSFakeConn{}, nil) + idleNew.lastUsedNano.Store(time.Now().Add(-1 * time.Minute).UnixNano()) + leased := newOpenAIWSConn("leased", accountID, &openAIWSFakeConn{}, nil) + require.True(t, leased.tryAcquire()) + leased.waiters.Store(2) + + ap.conns[idleOld.id] = idleOld + ap.conns[idleNew.id] = idleNew + ap.conns[leased.id] = leased + + oldest := pool.pickOldestIdleConnLocked(ap) + require.NotNil(t, oldest) + require.Equal(t, idleOld.id, oldest.id) + + inflight, waiters := accountPoolLoadLocked(ap) + require.Equal(t, 1, inflight) + require.Equal(t, 2, waiters) + + pool.accounts.Store(accountID, ap) + loadInflight, loadWaiters, conns := pool.AccountPoolLoad(accountID) + require.Equal(t, 1, loadInflight) + require.Equal(t, 2, loadWaiters) + require.Equal(t, 3, conns) + + zeroInflight, zeroWaiters, zeroConns := pool.AccountPoolLoad(0) + require.Equal(t, 0, zeroInflight) + require.Equal(t, 0, zeroWaiters) + require.Equal(t, 0, zeroConns) +} + +func TestOpenAIWSConnPool_Close_WaitsWorkerGroupAndNilStopChannel(t *testing.T) { + pool := &openAIWSConnPool{} + release := make(chan 
struct{}) + pool.workerWg.Add(1) + go func() { + defer pool.workerWg.Done() + <-release + }() + + closed := make(chan struct{}) + go func() { + pool.Close() + close(closed) + }() + + select { + case <-closed: + t.Fatal("Close 不应在 WaitGroup 未完成时提前返回") + case <-time.After(30 * time.Millisecond): + } + + close(release) + select { + case <-closed: + case <-time.After(time.Second): + t.Fatal("Close 未等待 workerWg 完成") + } +} + +func TestOpenAIWSConnPool_Close_ClosesOnlyIdleConnections(t *testing.T) { + pool := &openAIWSConnPool{ + workerStopCh: make(chan struct{}), + } + + accountID := int64(606) + ap := &openAIWSAccountPool{ + conns: map[string]*openAIWSConn{}, + } + idle := newOpenAIWSConn("idle_conn", accountID, &openAIWSFakeConn{}, nil) + leased := newOpenAIWSConn("leased_conn", accountID, &openAIWSFakeConn{}, nil) + require.True(t, leased.tryAcquire()) + + ap.conns[idle.id] = idle + ap.conns[leased.id] = leased + pool.accounts.Store(accountID, ap) + pool.accounts.Store("invalid-key", "invalid-value") + + pool.Close() + + select { + case <-idle.closedCh: + // idle should be closed + default: + t.Fatal("空闲连接应在 Close 时被关闭") + } + + select { + case <-leased.closedCh: + t.Fatal("已租赁连接不应在 Close 时被关闭") + default: + } + + leased.release() + pool.Close() +} + +func TestOpenAIWSConnPool_RunBackgroundPingSweep_ConcurrencyLimit(t *testing.T) { + cfg := &config.Config{} + pool := newOpenAIWSConnPool(cfg) + accountID := int64(505) + ap := pool.getOrCreateAccountPool(accountID) + + var current atomic.Int32 + var maxConcurrent atomic.Int32 + release := make(chan struct{}) + for i := 0; i < 25; i++ { + conn := newOpenAIWSConn(pool.nextConnID(accountID), accountID, &openAIWSPingBlockingConn{ + current: ¤t, + maxConcurrent: &maxConcurrent, + release: release, + }, nil) + ap.mu.Lock() + ap.conns[conn.id] = conn + ap.mu.Unlock() + } + + done := make(chan struct{}) + go func() { + pool.runBackgroundPingSweep() + close(done) + }() + + require.Eventually(t, func() bool { + return 
maxConcurrent.Load() >= 10 + }, time.Second, 10*time.Millisecond) + + close(release) + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("runBackgroundPingSweep 未在释放后完成") + } + + require.LessOrEqual(t, maxConcurrent.Load(), int32(10)) +} + +func TestOpenAIWSConnLease_BasicGetterBranches(t *testing.T) { + var nilLease *openAIWSConnLease + require.Equal(t, "", nilLease.ConnID()) + require.Equal(t, time.Duration(0), nilLease.QueueWaitDuration()) + require.Equal(t, time.Duration(0), nilLease.ConnPickDuration()) + require.False(t, nilLease.Reused()) + require.Equal(t, "", nilLease.HandshakeHeader("x-test")) + require.False(t, nilLease.IsPrewarmed()) + nilLease.MarkPrewarmed() + nilLease.Release() + + conn := newOpenAIWSConn("getter_conn", 1, &openAIWSFakeConn{}, http.Header{"X-Test": []string{"ok"}}) + lease := &openAIWSConnLease{ + conn: conn, + queueWait: 3 * time.Millisecond, + connPick: 4 * time.Millisecond, + reused: true, + } + require.Equal(t, "getter_conn", lease.ConnID()) + require.Equal(t, 3*time.Millisecond, lease.QueueWaitDuration()) + require.Equal(t, 4*time.Millisecond, lease.ConnPickDuration()) + require.True(t, lease.Reused()) + require.Equal(t, "ok", lease.HandshakeHeader("x-test")) + require.False(t, lease.IsPrewarmed()) + lease.MarkPrewarmed() + require.True(t, lease.IsPrewarmed()) + lease.Release() +} + +func TestOpenAIWSConnPool_UtilityBranches(t *testing.T) { + var nilPool *openAIWSConnPool + require.Equal(t, OpenAIWSPoolMetricsSnapshot{}, nilPool.SnapshotMetrics()) + require.Equal(t, OpenAIWSTransportMetricsSnapshot{}, nilPool.SnapshotTransportMetrics()) + + pool := &openAIWSConnPool{cfg: &config.Config{}} + pool.metrics.acquireTotal.Store(7) + pool.metrics.acquireReuseTotal.Store(3) + metrics := pool.SnapshotMetrics() + require.Equal(t, int64(7), metrics.AcquireTotal) + require.Equal(t, int64(3), metrics.AcquireReuseTotal) + + // 非 transport metrics dialer 路径 + pool.clientDialer = &openAIWSFakeDialer{} + require.Equal(t, 
OpenAIWSTransportMetricsSnapshot{}, pool.SnapshotTransportMetrics()) + pool.setClientDialerForTest(nil) + require.NotNil(t, pool.clientDialer) + + require.Equal(t, 8, nilPool.maxConnsHardCap()) + require.False(t, nilPool.dynamicMaxConnsEnabled()) + require.Equal(t, 1.0, nilPool.maxConnsFactorByAccount(nil)) + require.Equal(t, 0, nilPool.minIdlePerAccount()) + require.Equal(t, 4, nilPool.maxIdlePerAccount()) + require.Equal(t, 256, nilPool.queueLimitPerConn()) + require.Equal(t, 0.7, nilPool.targetUtilization()) + require.Equal(t, time.Duration(0), nilPool.prewarmCooldown()) + require.Equal(t, 10*time.Second, nilPool.dialTimeout()) + + // shouldSuppressPrewarmLocked 覆盖 3 条分支 + now := time.Now() + apNilFail := &openAIWSAccountPool{prewarmFails: 1} + require.False(t, pool.shouldSuppressPrewarmLocked(apNilFail, now)) + apZeroTime := &openAIWSAccountPool{prewarmFails: 2} + require.False(t, pool.shouldSuppressPrewarmLocked(apZeroTime, now)) + require.Equal(t, 0, apZeroTime.prewarmFails) + apOldFail := &openAIWSAccountPool{prewarmFails: 2, prewarmFailAt: now.Add(-openAIWSPrewarmFailureWindow - time.Second)} + require.False(t, pool.shouldSuppressPrewarmLocked(apOldFail, now)) + apRecentFail := &openAIWSAccountPool{prewarmFails: openAIWSPrewarmFailureSuppress, prewarmFailAt: now} + require.True(t, pool.shouldSuppressPrewarmLocked(apRecentFail, now)) + + // recordConnPickDuration 的保护分支 + nilPool.recordConnPickDuration(10 * time.Millisecond) + pool.recordConnPickDuration(-10 * time.Millisecond) + require.Equal(t, int64(1), pool.metrics.connPickTotal.Load()) + + // account pool 读写分支 + require.Nil(t, nilPool.getOrCreateAccountPool(1)) + require.Nil(t, pool.getOrCreateAccountPool(0)) + pool.accounts.Store(int64(7), "invalid") + ap := pool.getOrCreateAccountPool(7) + require.NotNil(t, ap) + _, ok := pool.getAccountPool(0) + require.False(t, ok) + _, ok = pool.getAccountPool(12345) + require.False(t, ok) + pool.accounts.Store(int64(8), "bad-type") + _, ok = pool.getAccountPool(8) 
+ require.False(t, ok) + + // health check 条件 + require.False(t, pool.shouldHealthCheckConn(nil)) + conn := newOpenAIWSConn("health", 1, &openAIWSFakeConn{}, nil) + conn.lastUsedNano.Store(time.Now().Add(-openAIWSConnHealthCheckIdle - time.Second).UnixNano()) + require.True(t, pool.shouldHealthCheckConn(conn)) +} + +func TestOpenAIWSConn_LeaseAndTimeHelpers_NilAndClosedBranches(t *testing.T) { + var nilConn *openAIWSConn + nilConn.touch() + require.Equal(t, time.Time{}, nilConn.createdAt()) + require.Equal(t, time.Time{}, nilConn.lastUsedAt()) + require.Equal(t, time.Duration(0), nilConn.idleDuration(time.Now())) + require.Equal(t, time.Duration(0), nilConn.age(time.Now())) + require.False(t, nilConn.isLeased()) + require.False(t, nilConn.isPrewarmed()) + nilConn.markPrewarmed() + + conn := newOpenAIWSConn("lease_state", 1, &openAIWSFakeConn{}, nil) + require.True(t, conn.tryAcquire()) + require.True(t, conn.isLeased()) + conn.release() + require.False(t, conn.isLeased()) + conn.close() + require.False(t, conn.tryAcquire()) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err := conn.acquire(ctx) + require.Error(t, err) +} + +func TestOpenAIWSConnLease_ReadWriteNilConnBranches(t *testing.T) { + lease := &openAIWSConnLease{} + require.ErrorIs(t, lease.WriteJSON(map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed) + require.ErrorIs(t, lease.WriteJSONContext(context.Background(), map[string]any{"k": "v"}), errOpenAIWSConnClosed) + _, err := lease.ReadMessage(10 * time.Millisecond) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + _, err = lease.ReadMessageContext(context.Background()) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + _, err = lease.ReadMessageWithContextTimeout(context.Background(), 10*time.Millisecond) + require.ErrorIs(t, err, errOpenAIWSConnClosed) +} + +func TestOpenAIWSConnLease_ReleasedLeaseGuards(t *testing.T) { + conn := newOpenAIWSConn("released_guard", 1, &openAIWSFakeConn{}, nil) + lease := 
&openAIWSConnLease{conn: conn} + + require.NoError(t, lease.PingWithTimeout(50*time.Millisecond)) + + lease.Release() + lease.Release() // idempotent + + require.ErrorIs(t, lease.WriteJSON(map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed) + require.ErrorIs(t, lease.WriteJSONContext(context.Background(), map[string]any{"k": "v"}), errOpenAIWSConnClosed) + require.ErrorIs(t, lease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed) + + _, err := lease.ReadMessage(10 * time.Millisecond) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + _, err = lease.ReadMessageContext(context.Background()) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + _, err = lease.ReadMessageWithContextTimeout(context.Background(), 10*time.Millisecond) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + + require.ErrorIs(t, lease.PingWithTimeout(50*time.Millisecond), errOpenAIWSConnClosed) +} + +func TestOpenAIWSConnLease_MarkBrokenAfterRelease_NoEviction(t *testing.T) { + conn := newOpenAIWSConn("released_markbroken", 7, &openAIWSFakeConn{}, nil) + ap := &openAIWSAccountPool{ + conns: map[string]*openAIWSConn{ + conn.id: conn, + }, + } + pool := &openAIWSConnPool{} + pool.accounts.Store(int64(7), ap) + + lease := &openAIWSConnLease{ + pool: pool, + accountID: 7, + conn: conn, + } + + lease.Release() + lease.MarkBroken() + + ap.mu.Lock() + _, exists := ap.conns[conn.id] + ap.mu.Unlock() + require.True(t, exists, "released lease should not evict active pool connection") +} + +func TestOpenAIWSConn_AdditionalGuardBranches(t *testing.T) { + var nilConn *openAIWSConn + require.False(t, nilConn.tryAcquire()) + require.ErrorIs(t, nilConn.acquire(context.Background()), errOpenAIWSConnClosed) + nilConn.release() + nilConn.close() + require.Equal(t, "", nilConn.handshakeHeader("x-test")) + + connBusy := newOpenAIWSConn("busy_ctx", 1, &openAIWSFakeConn{}, nil) + require.True(t, connBusy.tryAcquire()) + ctx, cancel := 
context.WithCancel(context.Background()) + cancel() + require.ErrorIs(t, connBusy.acquire(ctx), context.Canceled) + connBusy.release() + + connClosed := newOpenAIWSConn("closed_guard", 1, &openAIWSFakeConn{}, nil) + connClosed.close() + require.ErrorIs( + t, + connClosed.writeJSONWithTimeout(context.Background(), map[string]any{"k": "v"}, time.Second), + errOpenAIWSConnClosed, + ) + _, err := connClosed.readMessageWithContextTimeout(context.Background(), time.Second) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + require.ErrorIs(t, connClosed.pingWithTimeout(time.Second), errOpenAIWSConnClosed) + + connNoWS := newOpenAIWSConn("no_ws", 1, nil, nil) + require.ErrorIs(t, connNoWS.writeJSON(map[string]any{"k": "v"}, context.Background()), errOpenAIWSConnClosed) + _, err = connNoWS.readMessage(context.Background()) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + require.ErrorIs(t, connNoWS.pingWithTimeout(time.Second), errOpenAIWSConnClosed) + require.Equal(t, "", connNoWS.handshakeHeader("x-test")) + + connOK := newOpenAIWSConn("ok", 1, &openAIWSFakeConn{}, nil) + require.NoError(t, connOK.writeJSON(map[string]any{"k": "v"}, nil)) + _, err = connOK.readMessageWithContextTimeout(context.Background(), 0) + require.NoError(t, err) + require.NoError(t, connOK.pingWithTimeout(0)) + + connZero := newOpenAIWSConn("zero_ts", 1, &openAIWSFakeConn{}, nil) + connZero.createdAtNano.Store(0) + connZero.lastUsedNano.Store(0) + require.True(t, connZero.createdAt().IsZero()) + require.True(t, connZero.lastUsedAt().IsZero()) + require.Equal(t, time.Duration(0), connZero.idleDuration(time.Now())) + require.Equal(t, time.Duration(0), connZero.age(time.Now())) + + require.Nil(t, cloneOpenAIWSAcquireRequestPtr(nil)) + copied := cloneHeader(http.Header{ + "X-Empty": []string{}, + "X-Test": []string{"v1"}, + }) + require.Contains(t, copied, "X-Empty") + require.Nil(t, copied["X-Empty"]) + require.Equal(t, "v1", copied.Get("X-Test")) + + closeOpenAIWSConns([]*openAIWSConn{nil, connOK}) 
+} + +func TestOpenAIWSConnLease_MarkBrokenEvictsConn(t *testing.T) { + pool := newOpenAIWSConnPool(&config.Config{}) + accountID := int64(5001) + conn := newOpenAIWSConn("broken_me", accountID, &openAIWSFakeConn{}, nil) + ap := pool.getOrCreateAccountPool(accountID) + ap.mu.Lock() + ap.conns[conn.id] = conn + ap.mu.Unlock() + + lease := &openAIWSConnLease{ + pool: pool, + accountID: accountID, + conn: conn, + } + lease.MarkBroken() + + ap.mu.Lock() + _, exists := ap.conns[conn.id] + ap.mu.Unlock() + require.False(t, exists) + require.False(t, conn.tryAcquire(), "被标记为 broken 的连接应被关闭") +} + +func TestOpenAIWSConnPool_TargetConnCountAndPrewarmBranches(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + pool := newOpenAIWSConnPool(cfg) + + require.Equal(t, 0, pool.targetConnCountLocked(nil, 1)) + ap := &openAIWSAccountPool{conns: map[string]*openAIWSConn{}} + require.Equal(t, 0, pool.targetConnCountLocked(ap, 0)) + + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 3 + require.Equal(t, 1, pool.targetConnCountLocked(ap, 1), "minIdle 应被 maxConns 截断") + + // 覆盖 waiters>0 且 target 需要至少 len(conns)+1 的分支 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.9 + busy := newOpenAIWSConn("busy_target", 2, &openAIWSFakeConn{}, nil) + require.True(t, busy.tryAcquire()) + busy.waiters.Store(1) + ap.conns[busy.id] = busy + target := pool.targetConnCountLocked(ap, 4) + require.GreaterOrEqual(t, target, len(ap.conns)+1) + + // prewarm: account pool 缺失时,拨号后的连接应被关闭并提前返回 + req := openAIWSAcquireRequest{ + Account: &Account{ID: 999, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}, + WSURL: "wss://example.com/v1/responses", + } + pool.prewarmConns(999, req, 1) + + // prewarm: 拨号失败分支(prewarmFails 累加) + accountID := int64(1000) + failPool := newOpenAIWSConnPool(cfg) + failPool.setClientDialerForTest(&openAIWSAlwaysFailDialer{}) + apFail := failPool.getOrCreateAccountPool(accountID) + apFail.mu.Lock() + apFail.creating 
= 1 + apFail.mu.Unlock() + req.Account.ID = accountID + failPool.prewarmConns(accountID, req, 1) + apFail.mu.Lock() + require.GreaterOrEqual(t, apFail.prewarmFails, 1) + apFail.mu.Unlock() +} + +func TestOpenAIWSConnPool_Acquire_ErrorBranches(t *testing.T) { + var nilPool *openAIWSConnPool + _, err := nilPool.Acquire(context.Background(), openAIWSAcquireRequest{}) + require.Error(t, err) + + pool := newOpenAIWSConnPool(&config.Config{}) + _, err = pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: &Account{ID: 1}, + WSURL: " ", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "ws url is empty") + + // target=nil 分支:池满且仅有 nil 连接 + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 1 + fullPool := newOpenAIWSConnPool(cfg) + account := &Account{ID: 2001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + ap := fullPool.getOrCreateAccountPool(account.ID) + ap.mu.Lock() + ap.conns["nil"] = nil + ap.lastCleanupAt = time.Now() + ap.mu.Unlock() + _, err = fullPool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + }) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + + // queue full 分支:waiters 达上限 + account2 := &Account{ID: 2002, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + ap2 := fullPool.getOrCreateAccountPool(account2.ID) + conn := newOpenAIWSConn("queue_full", account2.ID, &openAIWSFakeConn{}, nil) + require.True(t, conn.tryAcquire()) + conn.waiters.Store(1) + ap2.mu.Lock() + ap2.conns[conn.id] = conn + ap2.lastCleanupAt = time.Now() + ap2.mu.Unlock() + _, err = fullPool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account2, + WSURL: "wss://example.com/v1/responses", + }) + require.ErrorIs(t, err, errOpenAIWSConnQueueFull) +} + +type openAIWSFakeDialer struct{} + +func (d *openAIWSFakeDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, 
+) (openAIWSClientConn, int, http.Header, error) { + _ = ctx + _ = wsURL + _ = headers + _ = proxyURL + return &openAIWSFakeConn{}, 0, nil, nil +} + +type openAIWSCountingDialer struct { + mu sync.Mutex + dialCount int +} + +type openAIWSAlwaysFailDialer struct { + mu sync.Mutex + dialCount int +} + +type openAIWSPingBlockingConn struct { + current *atomic.Int32 + maxConcurrent *atomic.Int32 + release <-chan struct{} +} + +func (c *openAIWSPingBlockingConn) WriteJSON(context.Context, any) error { + return nil +} + +func (c *openAIWSPingBlockingConn) ReadMessage(context.Context) ([]byte, error) { + return []byte(`{"type":"response.completed","response":{"id":"resp_blocking_ping"}}`), nil +} + +func (c *openAIWSPingBlockingConn) Ping(ctx context.Context) error { + if c.current == nil || c.maxConcurrent == nil { + return nil + } + + now := c.current.Add(1) + for { + prev := c.maxConcurrent.Load() + if now <= prev || c.maxConcurrent.CompareAndSwap(prev, now) { + break + } + } + defer c.current.Add(-1) + + select { + case <-ctx.Done(): + return ctx.Err() + case <-c.release: + return nil + } +} + +func (c *openAIWSPingBlockingConn) Close() error { + return nil +} + +func (d *openAIWSCountingDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + _ = ctx + _ = wsURL + _ = headers + _ = proxyURL + d.mu.Lock() + d.dialCount++ + d.mu.Unlock() + return &openAIWSFakeConn{}, 0, nil, nil +} + +func (d *openAIWSCountingDialer) DialCount() int { + d.mu.Lock() + defer d.mu.Unlock() + return d.dialCount +} + +func (d *openAIWSAlwaysFailDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + _ = ctx + _ = wsURL + _ = headers + _ = proxyURL + d.mu.Lock() + d.dialCount++ + d.mu.Unlock() + return nil, 503, nil, errors.New("dial failed") +} + +func (d *openAIWSAlwaysFailDialer) DialCount() int { + 
d.mu.Lock() + defer d.mu.Unlock() + return d.dialCount +} + +type openAIWSFakeConn struct { + mu sync.Mutex + closed bool + payload [][]byte +} + +func (c *openAIWSFakeConn) WriteJSON(ctx context.Context, value any) error { + _ = ctx + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return errors.New("closed") + } + c.payload = append(c.payload, []byte("ok")) + _ = value + return nil +} + +func (c *openAIWSFakeConn) ReadMessage(ctx context.Context) ([]byte, error) { + _ = ctx + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return nil, errors.New("closed") + } + return []byte(`{"type":"response.completed","response":{"id":"resp_fake"}}`), nil +} + +func (c *openAIWSFakeConn) Ping(ctx context.Context) error { + _ = ctx + return nil +} + +func (c *openAIWSFakeConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + c.closed = true + return nil +} + +type openAIWSBlockingConn struct { + readDelay time.Duration +} + +func (c *openAIWSBlockingConn) WriteJSON(ctx context.Context, value any) error { + _ = ctx + _ = value + return nil +} + +func (c *openAIWSBlockingConn) ReadMessage(ctx context.Context) ([]byte, error) { + delay := c.readDelay + if delay <= 0 { + delay = 10 * time.Millisecond + } + timer := time.NewTimer(delay) + defer timer.Stop() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-timer.C: + return []byte(`{"type":"response.completed","response":{"id":"resp_blocking"}}`), nil + } +} + +func (c *openAIWSBlockingConn) Ping(ctx context.Context) error { + _ = ctx + return nil +} + +func (c *openAIWSBlockingConn) Close() error { + return nil +} + +type openAIWSWriteBlockingConn struct{} + +func (c *openAIWSWriteBlockingConn) WriteJSON(ctx context.Context, _ any) error { + <-ctx.Done() + return ctx.Err() +} + +func (c *openAIWSWriteBlockingConn) ReadMessage(context.Context) ([]byte, error) { + return []byte(`{"type":"response.completed","response":{"id":"resp_write_block"}}`), nil +} + +func (c *openAIWSWriteBlockingConn) 
Ping(context.Context) error { + return nil +} + +func (c *openAIWSWriteBlockingConn) Close() error { + return nil +} + +type openAIWSPingFailConn struct{} + +func (c *openAIWSPingFailConn) WriteJSON(context.Context, any) error { + return nil +} + +func (c *openAIWSPingFailConn) ReadMessage(context.Context) ([]byte, error) { + return []byte(`{"type":"response.completed","response":{"id":"resp_ping_fail"}}`), nil +} + +func (c *openAIWSPingFailConn) Ping(context.Context) error { + return errors.New("ping failed") +} + +func (c *openAIWSPingFailConn) Close() error { + return nil +} + +type openAIWSContextProbeConn struct { + lastWriteCtx context.Context +} + +func (c *openAIWSContextProbeConn) WriteJSON(ctx context.Context, _ any) error { + c.lastWriteCtx = ctx + return nil +} + +func (c *openAIWSContextProbeConn) ReadMessage(context.Context) ([]byte, error) { + return []byte(`{"type":"response.completed","response":{"id":"resp_ctx_probe"}}`), nil +} + +func (c *openAIWSContextProbeConn) Ping(context.Context) error { + return nil +} + +func (c *openAIWSContextProbeConn) Close() error { + return nil +} + +type openAIWSNilConnDialer struct{} + +func (d *openAIWSNilConnDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + _ = ctx + _ = wsURL + _ = headers + _ = proxyURL + return nil, 200, nil, nil +} + +func TestOpenAIWSConnPool_DialConnNilConnection(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1 + + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(&openAIWSNilConnDialer{}) + account := &Account{ID: 91, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + _, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "nil connection") +} + +func 
TestOpenAIWSConnPool_SnapshotTransportMetrics(t *testing.T) { + cfg := &config.Config{} + pool := newOpenAIWSConnPool(cfg) + + dialer, ok := pool.clientDialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + _, err := dialer.proxyHTTPClient("http://127.0.0.1:28080") + require.NoError(t, err) + _, err = dialer.proxyHTTPClient("http://127.0.0.1:28080") + require.NoError(t, err) + _, err = dialer.proxyHTTPClient("http://127.0.0.1:28081") + require.NoError(t, err) + + snapshot := pool.SnapshotTransportMetrics() + require.Equal(t, int64(1), snapshot.ProxyClientCacheHits) + require.Equal(t, int64(2), snapshot.ProxyClientCacheMisses) + require.InDelta(t, 1.0/3.0, snapshot.TransportReuseRatio, 0.0001) +} diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go new file mode 100644 index 00000000..df4d4871 --- /dev/null +++ b/backend/internal/service/openai_ws_protocol_forward_test.go @@ -0,0 +1,1218 @@ +package service + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestOpenAIGatewayService_Forward_PreservePreviousResponseIDWhenWSEnabled(t *testing.T) { + gin.SetMode(gin.TestMode) + wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer wsFallbackServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + 
Body: io.NopCloser(strings.NewReader( + `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`, + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 1, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsFallbackServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "WS 模式下失败时不应回退 HTTP") +} + +func TestOpenAIGatewayService_Forward_HTTPIngressStaysHTTPWhenWSEnabled(t *testing.T) { + gin.SetMode(gin.TestMode) + wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer wsFallbackServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportHTTP) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + 
`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`, + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 101, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsFallbackServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_http_keep","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.OpenAIWSMode, "HTTP 入站应保持 HTTP 转发") + require.NotNil(t, upstream.lastReq, "HTTP 入站应命中 HTTP 上游") + require.False(t, gjson.GetBytes(upstream.lastBody, "previous_response_id").Exists(), "HTTP 路径应沿用原逻辑移除 previous_response_id") + + decision, _ := c.Get("openai_ws_transport_decision") + reason, _ := c.Get("openai_ws_transport_reason") + require.Equal(t, string(OpenAIUpstreamTransportHTTPSSE), decision) + require.Equal(t, "client_protocol_http", reason) +} + +func TestOpenAIGatewayService_Forward_RemovePreviousResponseIDWhenWSDisabled(t *testing.T) { + gin.SetMode(gin.TestMode) + wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer wsFallbackServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = 
httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`, + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = false + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 1, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsFallbackServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, gjson.GetBytes(upstream.lastBody, "previous_response_id").Exists()) +} + +func TestOpenAIGatewayService_Forward_WSv2Dial426FallbackHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + ws426Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUpgradeRequired) + _, _ = w.Write([]byte(`upgrade required`)) + })) + defer ws426Server.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := 
&httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + `{"usage":{"input_tokens":8,"output_tokens":9,"input_tokens_details":{"cached_tokens":1}}}`, + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 12, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": ws426Server.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_426","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "upgrade_required") + require.Nil(t, upstream.lastReq, "WS 模式下不应再回退 HTTP") + require.Equal(t, http.StatusUpgradeRequired, rec.Code) + require.Contains(t, rec.Body.String(), "426") +} + +func TestOpenAIGatewayService_Forward_WSv2FallbackCoolingSkipWS(t *testing.T) { + gin.SetMode(gin.TestMode) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + 
c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + `{"usage":{"input_tokens":2,"output_tokens":3,"input_tokens_details":{"cached_tokens":0}}}`, + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 30 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 21, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + svc.markOpenAIWSFallbackCooling(account.ID, "upgrade_required") + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_cooling","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "WS 模式下不应再回退 HTTP") + + _, ok := c.Get("openai_ws_fallback_cooling") + require.False(t, ok, "已移除 fallback cooling 快捷回退路径") +} + +func TestOpenAIGatewayService_Forward_ReturnErrorWhenOnlyWSv1Enabled(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ 
+ resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`, + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsockets = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 31, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://api.openai.com/v1/responses", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_v1","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "ws v1") + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, rec.Body.String(), "WSv1") + require.Nil(t, upstream.lastReq, "WSv1 不支持时不应触发 HTTP 上游请求") +} + +func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) { + cfg := &config.Config{} + svc := NewOpenAIGatewayService( + nil, + nil, + nil, + nil, + nil, + cfg, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + ) + + decision := svc.getOpenAIWSProtocolResolver().Resolve(nil) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "account_missing", decision.Reason) +} + +func 
TestOpenAIGatewayService_Forward_WSv2FallbackWhenResponseAlreadyWrittenReturnsWSError(t *testing.T) { + gin.SetMode(gin.TestMode) + ws426Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUpgradeRequired) + _, _ = w.Write([]byte(`upgrade required`)) + })) + defer ws426Server.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + c.String(http.StatusAccepted, "already-written") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 41, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": ws426Server.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "ws fallback") + require.Nil(t, upstream.lastReq, "已写下游响应时,不应再回退 HTTP") +} + +func 
TestOpenAIGatewayService_Forward_WSv2StreamEarlyCloseFallbackHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + + // 仅发送 response.created(非 token 事件)后立即关闭, + // 模拟线上“上游早期内部错误断连”的场景。 + if err := conn.WriteJSON(map[string]any{ + "type": "response.created", + "response": map[string]any{ + "id": "resp_ws_created_only", + "model": "gpt-5.3-codex", + }, + }); err != nil { + t.Errorf("write response.created failed: %v", err) + return + } + closePayload := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "") + _ = conn.WriteControl(websocket.CloseMessage, closePayload, time.Now().Add(time.Second)) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.output_text.delta\",\"delta\":\"ok\"}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_http_fallback\",\"usage\":{\"input_tokens\":2,\"output_tokens\":1}}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = 
true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 88, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":true,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "WS 早期断连后不应再回退 HTTP") + require.Empty(t, rec.Body.String(), "未产出 token 前上游断连时不应写入下游半截流") +} + +func TestOpenAIGatewayService_Forward_WSv2RetryFiveTimesThenFallbackHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + closePayload := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "") + _ = conn.WriteControl(websocket.CloseMessage, closePayload, time.Now().Add(time.Second)) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + 
c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.output_text.delta\",\"delta\":\"ok\"}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_retry_http_fallback\",\"usage\":{\"input_tokens\":2,\"output_tokens\":1}}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 89, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":true,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "WS 重连耗尽后不应再回退 HTTP") + require.Equal(t, int32(openAIWSReconnectRetryLimit+1), wsAttempts.Load()) +} + +func TestOpenAIGatewayService_Forward_WSv2PolicyViolationFastFallbackHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + closePayload := websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "") + _ = conn.WriteControl(websocket.CloseMessage, closePayload, time.Now().Add(time.Second)) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_policy_fallback","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + cfg.Gateway.OpenAIWS.RetryBackoffInitialMS = 1 + cfg.Gateway.OpenAIWS.RetryBackoffMaxMS = 2 + cfg.Gateway.OpenAIWS.RetryJitterRatio = 0 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 8901, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: 
map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "策略违规关闭后不应回退 HTTP") + require.Equal(t, int32(1), wsAttempts.Load(), "策略违规不应进行 WS 重试") +} + +func TestOpenAIGatewayService_Forward_WSv2ConnectionLimitReachedRetryThenFallbackHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + _ = conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "websocket_connection_limit_reached", + "type": "server_error", + "message": "websocket connection limit reached", + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_retry_limit","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true 
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 90, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "触发 websocket_connection_limit_reached 后不应回退 HTTP") + require.Equal(t, int32(openAIWSReconnectRetryLimit+1), wsAttempts.Load()) +} + +func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundRecoversByDroppingPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + var wsRequestPayloads [][]byte + var wsRequestMu sync.Mutex + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempt := wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + reqRaw, _ := json.Marshal(req) + wsRequestMu.Lock() + wsRequestPayloads = append(wsRequestPayloads, reqRaw) + wsRequestMu.Unlock() + if attempt == 1 { + _ = 
conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "previous_response_not_found", + "type": "invalid_request_error", + "message": "previous response not found", + }, + }) + return + } + _ = conn.WriteJSON(map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": "resp_ws_prev_recover_ok", + "model": "gpt-5.3-codex", + "usage": map[string]any{ + "input_tokens": 1, + "output_tokens": 1, + "input_tokens_details": map[string]any{ + "cached_tokens": 0, + }, + }, + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 91, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := 
[]byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_ws_prev_recover_ok", result.RequestID) + require.Nil(t, upstream.lastReq, "previous_response_not_found 不应回退 HTTP") + require.Equal(t, int32(2), wsAttempts.Load(), "previous_response_not_found 应触发一次去掉 previous_response_id 的恢复重试") + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "resp_ws_prev_recover_ok", gjson.Get(rec.Body.String(), "id").String()) + + wsRequestMu.Lock() + requests := append([][]byte(nil), wsRequestPayloads...) + wsRequestMu.Unlock() + require.Len(t, requests, 2) + require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应保留 previous_response_id") + require.False(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id") +} + +func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryForFunctionCallOutput(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + var wsRequestPayloads [][]byte + var wsRequestMu sync.Mutex + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + reqRaw, _ := json.Marshal(req) + wsRequestMu.Lock() + wsRequestPayloads = append(wsRequestPayloads, reqRaw) + wsRequestMu.Unlock() + _ = conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": 
"previous_response_not_found", + "type": "invalid_request_error", + "message": "previous response not found", + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 92, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "previous_response_not_found 不应回退 HTTP") + require.Equal(t, int32(1), wsAttempts.Load(), "function_call_output 场景应跳过 previous_response_not_found 自动恢复") + require.Equal(t, http.StatusBadRequest, 
rec.Code) + require.Contains(t, strings.ToLower(rec.Body.String()), "previous response not found") + + wsRequestMu.Lock() + requests := append([][]byte(nil), wsRequestPayloads...) + wsRequestMu.Unlock() + require.Len(t, requests, 1) + require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists()) +} + +func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryWithoutPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + var wsRequestPayloads [][]byte + var wsRequestMu sync.Mutex + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + reqRaw, _ := json.Marshal(req) + wsRequestMu.Lock() + wsRequestPayloads = append(wsRequestPayloads, reqRaw) + wsRequestMu.Unlock() + _ = conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "previous_response_not_found", + "type": "invalid_request_error", + "message": "previous response not found", + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + 
cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 93, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "WS 模式下 previous_response_not_found 不应回退 HTTP") + require.Equal(t, int32(1), wsAttempts.Load(), "缺少 previous_response_id 时应跳过自动恢复重试") + require.Equal(t, http.StatusBadRequest, rec.Code) + + wsRequestMu.Lock() + requests := append([][]byte(nil), wsRequestPayloads...) 
+ wsRequestMu.Unlock() + require.Len(t, requests, 1) + require.False(t, gjson.GetBytes(requests[0], "previous_response_id").Exists()) +} + +func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundOnlyRecoversOnce(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + var wsRequestPayloads [][]byte + var wsRequestMu sync.Mutex + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + reqRaw, _ := json.Marshal(req) + wsRequestMu.Lock() + wsRequestPayloads = append(wsRequestPayloads, reqRaw) + wsRequestMu.Unlock() + _ = conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "previous_response_not_found", + "type": "invalid_request_error", + "message": "previous response not found", + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + 
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 94, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "WS 模式下 previous_response_not_found 不应回退 HTTP") + require.Equal(t, int32(2), wsAttempts.Load(), "应只允许一次自动恢复重试") + require.Equal(t, http.StatusBadRequest, rec.Code) + + wsRequestMu.Lock() + requests := append([][]byte(nil), wsRequestPayloads...) 
+ wsRequestMu.Unlock() + require.Len(t, requests, 2) + require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应包含 previous_response_id") + require.False(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id") +} diff --git a/backend/internal/service/openai_ws_protocol_resolver.go b/backend/internal/service/openai_ws_protocol_resolver.go new file mode 100644 index 00000000..368643be --- /dev/null +++ b/backend/internal/service/openai_ws_protocol_resolver.go @@ -0,0 +1,117 @@ +package service + +import "github.com/Wei-Shaw/sub2api/internal/config" + +// OpenAIUpstreamTransport 表示 OpenAI 上游传输协议。 +type OpenAIUpstreamTransport string + +const ( + OpenAIUpstreamTransportAny OpenAIUpstreamTransport = "" + OpenAIUpstreamTransportHTTPSSE OpenAIUpstreamTransport = "http_sse" + OpenAIUpstreamTransportResponsesWebsocket OpenAIUpstreamTransport = "responses_websockets" + OpenAIUpstreamTransportResponsesWebsocketV2 OpenAIUpstreamTransport = "responses_websockets_v2" +) + +// OpenAIWSProtocolDecision 表示协议决策结果。 +type OpenAIWSProtocolDecision struct { + Transport OpenAIUpstreamTransport + Reason string +} + +// OpenAIWSProtocolResolver 定义 OpenAI 上游协议决策。 +type OpenAIWSProtocolResolver interface { + Resolve(account *Account) OpenAIWSProtocolDecision +} + +type defaultOpenAIWSProtocolResolver struct { + cfg *config.Config +} + +// NewOpenAIWSProtocolResolver 创建默认协议决策器。 +func NewOpenAIWSProtocolResolver(cfg *config.Config) OpenAIWSProtocolResolver { + return &defaultOpenAIWSProtocolResolver{cfg: cfg} +} + +func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProtocolDecision { + if account == nil { + return openAIWSHTTPDecision("account_missing") + } + if !account.IsOpenAI() { + return openAIWSHTTPDecision("platform_not_openai") + } + if account.IsOpenAIWSForceHTTPEnabled() { + return openAIWSHTTPDecision("account_force_http") + } + if r == nil || r.cfg == nil { + return 
openAIWSHTTPDecision("config_missing") + } + + wsCfg := r.cfg.Gateway.OpenAIWS + if wsCfg.ForceHTTP { + return openAIWSHTTPDecision("global_force_http") + } + if !wsCfg.Enabled { + return openAIWSHTTPDecision("global_disabled") + } + if account.IsOpenAIOAuth() { + if !wsCfg.OAuthEnabled { + return openAIWSHTTPDecision("oauth_disabled") + } + } else if account.IsOpenAIApiKey() { + if !wsCfg.APIKeyEnabled { + return openAIWSHTTPDecision("apikey_disabled") + } + } else { + return openAIWSHTTPDecision("unknown_auth_type") + } + if wsCfg.ModeRouterV2Enabled { + mode := account.ResolveOpenAIResponsesWebSocketV2Mode(wsCfg.IngressModeDefault) + switch mode { + case OpenAIWSIngressModeOff: + return openAIWSHTTPDecision("account_mode_off") + case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated: + // continue + default: + return openAIWSHTTPDecision("account_mode_off") + } + if account.Concurrency <= 0 { + return openAIWSHTTPDecision("account_concurrency_invalid") + } + if wsCfg.ResponsesWebsocketsV2 { + return OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportResponsesWebsocketV2, + Reason: "ws_v2_mode_" + mode, + } + } + if wsCfg.ResponsesWebsockets { + return OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportResponsesWebsocket, + Reason: "ws_v1_mode_" + mode, + } + } + return openAIWSHTTPDecision("feature_disabled") + } + if !account.IsOpenAIResponsesWebSocketV2Enabled() { + return openAIWSHTTPDecision("account_disabled") + } + if wsCfg.ResponsesWebsocketsV2 { + return OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportResponsesWebsocketV2, + Reason: "ws_v2_enabled", + } + } + if wsCfg.ResponsesWebsockets { + return OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportResponsesWebsocket, + Reason: "ws_v1_enabled", + } + } + return openAIWSHTTPDecision("feature_disabled") +} + +func openAIWSHTTPDecision(reason string) OpenAIWSProtocolDecision { + return OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportHTTPSSE, 
+ Reason: reason, + } +} diff --git a/backend/internal/service/openai_ws_protocol_resolver_test.go b/backend/internal/service/openai_ws_protocol_resolver_test.go new file mode 100644 index 00000000..5be76e28 --- /dev/null +++ b/backend/internal/service/openai_ws_protocol_resolver_test.go @@ -0,0 +1,203 @@ +package service + +import ( + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestOpenAIWSProtocolResolver_Resolve(t *testing.T) { + baseCfg := &config.Config{} + baseCfg.Gateway.OpenAIWS.Enabled = true + baseCfg.Gateway.OpenAIWS.OAuthEnabled = true + baseCfg.Gateway.OpenAIWS.APIKeyEnabled = true + baseCfg.Gateway.OpenAIWS.ResponsesWebsockets = false + baseCfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + openAIOAuthEnabled := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": true, + }, + } + + t.Run("v2优先", func(t *testing.T) { + decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(openAIOAuthEnabled) + require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) + require.Equal(t, "ws_v2_enabled", decision.Reason) + }) + + t.Run("v2关闭时回退v1", func(t *testing.T) { + cfg := *baseCfg + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false + cfg.Gateway.OpenAIWS.ResponsesWebsockets = true + + decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled) + require.Equal(t, OpenAIUpstreamTransportResponsesWebsocket, decision.Transport) + require.Equal(t, "ws_v1_enabled", decision.Reason) + }) + + t.Run("透传开关不影响WS协议判定", func(t *testing.T) { + account := *openAIOAuthEnabled + account.Extra = map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": true, + "openai_passthrough": true, + } + decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account) + require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) + require.Equal(t, "ws_v2_enabled", 
decision.Reason) + }) + + t.Run("账号级强制HTTP", func(t *testing.T) { + account := *openAIOAuthEnabled + account.Extra = map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": true, + "openai_ws_force_http": true, + } + decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "account_force_http", decision.Reason) + }) + + t.Run("全局关闭保持HTTP", func(t *testing.T) { + cfg := *baseCfg + cfg.Gateway.OpenAIWS.Enabled = false + decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "global_disabled", decision.Reason) + }) + + t.Run("账号开关关闭保持HTTP", func(t *testing.T) { + account := *openAIOAuthEnabled + account.Extra = map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": false, + } + decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "account_disabled", decision.Reason) + }) + + t.Run("OAuth账号不会读取API Key专用开关", func(t *testing.T) { + account := *openAIOAuthEnabled + account.Extra = map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + } + decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "account_disabled", decision.Reason) + }) + + t.Run("兼容旧键openai_ws_enabled", func(t *testing.T) { + account := *openAIOAuthEnabled + account.Extra = map[string]any{ + "openai_ws_enabled": true, + } + decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account) + require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) + require.Equal(t, "ws_v2_enabled", decision.Reason) + }) + + t.Run("按账号类型开关控制", func(t *testing.T) { + cfg := *baseCfg + cfg.Gateway.OpenAIWS.OAuthEnabled = false + decision := 
NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "oauth_disabled", decision.Reason) + }) + + t.Run("API Key 账号关闭开关时回退HTTP", func(t *testing.T) { + cfg := *baseCfg + cfg.Gateway.OpenAIWS.APIKeyEnabled = false + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(account) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "apikey_disabled", decision.Reason) + }) + + t.Run("未知认证类型回退HTTP", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: "unknown_type", + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(account) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "unknown_auth_type", decision.Reason) + }) +} + +func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared + + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, + }, + } + + t.Run("dedicated mode routes to ws v2", func(t *testing.T) { + decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account) + require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) + require.Equal(t, "ws_v2_mode_dedicated", decision.Reason) + }) + + t.Run("off mode routes to http", 
func(t *testing.T) { + offAccount := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff, + }, + } + decision := NewOpenAIWSProtocolResolver(cfg).Resolve(offAccount) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "account_mode_off", decision.Reason) + }) + + t.Run("legacy boolean maps to shared in v2 router", func(t *testing.T) { + legacyAccount := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount) + require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) + require.Equal(t, "ws_v2_mode_shared", decision.Reason) + }) + + t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) { + invalidConcurrency := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared, + }, + } + decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "account_concurrency_invalid", decision.Reason) + }) +} diff --git a/backend/internal/service/openai_ws_state_store.go b/backend/internal/service/openai_ws_state_store.go new file mode 100644 index 00000000..b606baa1 --- /dev/null +++ b/backend/internal/service/openai_ws_state_store.go @@ -0,0 +1,440 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "strings" + "sync" + "sync/atomic" + "time" +) + +const ( + openAIWSResponseAccountCachePrefix = "openai:response:" + openAIWSStateStoreCleanupInterval = time.Minute + openAIWSStateStoreCleanupMaxPerMap = 512 + openAIWSStateStoreMaxEntriesPerMap = 65536 
+ openAIWSStateStoreRedisTimeout = 3 * time.Second +) + +type openAIWSAccountBinding struct { + accountID int64 + expiresAt time.Time +} + +type openAIWSConnBinding struct { + connID string + expiresAt time.Time +} + +type openAIWSTurnStateBinding struct { + turnState string + expiresAt time.Time +} + +type openAIWSSessionConnBinding struct { + connID string + expiresAt time.Time +} + +// OpenAIWSStateStore 管理 WSv2 的粘连状态。 +// - response_id -> account_id 用于续链路由 +// - response_id -> conn_id 用于连接内上下文复用 +// +// response_id -> account_id 优先走 GatewayCache(Redis),同时维护本地热缓存。 +// response_id -> conn_id 仅在本进程内有效。 +type OpenAIWSStateStore interface { + BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error + GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error) + DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error + + BindResponseConn(responseID, connID string, ttl time.Duration) + GetResponseConn(responseID string) (string, bool) + DeleteResponseConn(responseID string) + + BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration) + GetSessionTurnState(groupID int64, sessionHash string) (string, bool) + DeleteSessionTurnState(groupID int64, sessionHash string) + + BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration) + GetSessionConn(groupID int64, sessionHash string) (string, bool) + DeleteSessionConn(groupID int64, sessionHash string) +} + +type defaultOpenAIWSStateStore struct { + cache GatewayCache + + responseToAccountMu sync.RWMutex + responseToAccount map[string]openAIWSAccountBinding + responseToConnMu sync.RWMutex + responseToConn map[string]openAIWSConnBinding + sessionToTurnStateMu sync.RWMutex + sessionToTurnState map[string]openAIWSTurnStateBinding + sessionToConnMu sync.RWMutex + sessionToConn map[string]openAIWSSessionConnBinding + + lastCleanupUnixNano atomic.Int64 +} + +// 
NewOpenAIWSStateStore 创建默认 WS 状态存储。 +func NewOpenAIWSStateStore(cache GatewayCache) OpenAIWSStateStore { + store := &defaultOpenAIWSStateStore{ + cache: cache, + responseToAccount: make(map[string]openAIWSAccountBinding, 256), + responseToConn: make(map[string]openAIWSConnBinding, 256), + sessionToTurnState: make(map[string]openAIWSTurnStateBinding, 256), + sessionToConn: make(map[string]openAIWSSessionConnBinding, 256), + } + store.lastCleanupUnixNano.Store(time.Now().UnixNano()) + return store +} + +func (s *defaultOpenAIWSStateStore) BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" || accountID <= 0 { + return nil + } + ttl = normalizeOpenAIWSTTL(ttl) + s.maybeCleanup() + + expiresAt := time.Now().Add(ttl) + s.responseToAccountMu.Lock() + ensureBindingCapacity(s.responseToAccount, id, openAIWSStateStoreMaxEntriesPerMap) + s.responseToAccount[id] = openAIWSAccountBinding{accountID: accountID, expiresAt: expiresAt} + s.responseToAccountMu.Unlock() + + if s.cache == nil { + return nil + } + cacheKey := openAIWSResponseAccountCacheKey(id) + cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx) + defer cancel() + return s.cache.SetSessionAccountID(cacheCtx, groupID, cacheKey, accountID, ttl) +} + +func (s *defaultOpenAIWSStateStore) GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error) { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" { + return 0, nil + } + s.maybeCleanup() + + now := time.Now() + s.responseToAccountMu.RLock() + if binding, ok := s.responseToAccount[id]; ok { + if now.Before(binding.expiresAt) { + accountID := binding.accountID + s.responseToAccountMu.RUnlock() + return accountID, nil + } + } + s.responseToAccountMu.RUnlock() + + if s.cache == nil { + return 0, nil + } + + cacheKey := openAIWSResponseAccountCacheKey(id) + cacheCtx, cancel := 
withOpenAIWSStateStoreRedisTimeout(ctx) + defer cancel() + accountID, err := s.cache.GetSessionAccountID(cacheCtx, groupID, cacheKey) + if err != nil || accountID <= 0 { + // 缓存读取失败不阻断主流程,按未命中降级。 + return 0, nil + } + return accountID, nil +} + +func (s *defaultOpenAIWSStateStore) DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" { + return nil + } + s.responseToAccountMu.Lock() + delete(s.responseToAccount, id) + s.responseToAccountMu.Unlock() + + if s.cache == nil { + return nil + } + cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx) + defer cancel() + return s.cache.DeleteSessionAccountID(cacheCtx, groupID, openAIWSResponseAccountCacheKey(id)) +} + +func (s *defaultOpenAIWSStateStore) BindResponseConn(responseID, connID string, ttl time.Duration) { + id := normalizeOpenAIWSResponseID(responseID) + conn := strings.TrimSpace(connID) + if id == "" || conn == "" { + return + } + ttl = normalizeOpenAIWSTTL(ttl) + s.maybeCleanup() + + s.responseToConnMu.Lock() + ensureBindingCapacity(s.responseToConn, id, openAIWSStateStoreMaxEntriesPerMap) + s.responseToConn[id] = openAIWSConnBinding{ + connID: conn, + expiresAt: time.Now().Add(ttl), + } + s.responseToConnMu.Unlock() +} + +func (s *defaultOpenAIWSStateStore) GetResponseConn(responseID string) (string, bool) { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" { + return "", false + } + s.maybeCleanup() + + now := time.Now() + s.responseToConnMu.RLock() + binding, ok := s.responseToConn[id] + s.responseToConnMu.RUnlock() + if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.connID) == "" { + return "", false + } + return binding.connID, true +} + +func (s *defaultOpenAIWSStateStore) DeleteResponseConn(responseID string) { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" { + return + } + s.responseToConnMu.Lock() + delete(s.responseToConn, id) + s.responseToConnMu.Unlock() 
+} + +func (s *defaultOpenAIWSStateStore) BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + state := strings.TrimSpace(turnState) + if key == "" || state == "" { + return + } + ttl = normalizeOpenAIWSTTL(ttl) + s.maybeCleanup() + + s.sessionToTurnStateMu.Lock() + ensureBindingCapacity(s.sessionToTurnState, key, openAIWSStateStoreMaxEntriesPerMap) + s.sessionToTurnState[key] = openAIWSTurnStateBinding{ + turnState: state, + expiresAt: time.Now().Add(ttl), + } + s.sessionToTurnStateMu.Unlock() +} + +func (s *defaultOpenAIWSStateStore) GetSessionTurnState(groupID int64, sessionHash string) (string, bool) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + if key == "" { + return "", false + } + s.maybeCleanup() + + now := time.Now() + s.sessionToTurnStateMu.RLock() + binding, ok := s.sessionToTurnState[key] + s.sessionToTurnStateMu.RUnlock() + if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.turnState) == "" { + return "", false + } + return binding.turnState, true +} + +func (s *defaultOpenAIWSStateStore) DeleteSessionTurnState(groupID int64, sessionHash string) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + if key == "" { + return + } + s.sessionToTurnStateMu.Lock() + delete(s.sessionToTurnState, key) + s.sessionToTurnStateMu.Unlock() +} + +func (s *defaultOpenAIWSStateStore) BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + conn := strings.TrimSpace(connID) + if key == "" || conn == "" { + return + } + ttl = normalizeOpenAIWSTTL(ttl) + s.maybeCleanup() + + s.sessionToConnMu.Lock() + ensureBindingCapacity(s.sessionToConn, key, openAIWSStateStoreMaxEntriesPerMap) + s.sessionToConn[key] = openAIWSSessionConnBinding{ + connID: conn, + expiresAt: time.Now().Add(ttl), + } + s.sessionToConnMu.Unlock() +} + +func (s 
*defaultOpenAIWSStateStore) GetSessionConn(groupID int64, sessionHash string) (string, bool) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + if key == "" { + return "", false + } + s.maybeCleanup() + + now := time.Now() + s.sessionToConnMu.RLock() + binding, ok := s.sessionToConn[key] + s.sessionToConnMu.RUnlock() + if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.connID) == "" { + return "", false + } + return binding.connID, true +} + +func (s *defaultOpenAIWSStateStore) DeleteSessionConn(groupID int64, sessionHash string) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + if key == "" { + return + } + s.sessionToConnMu.Lock() + delete(s.sessionToConn, key) + s.sessionToConnMu.Unlock() +} + +func (s *defaultOpenAIWSStateStore) maybeCleanup() { + if s == nil { + return + } + now := time.Now() + last := time.Unix(0, s.lastCleanupUnixNano.Load()) + if now.Sub(last) < openAIWSStateStoreCleanupInterval { + return + } + if !s.lastCleanupUnixNano.CompareAndSwap(last.UnixNano(), now.UnixNano()) { + return + } + + // 增量限额清理,避免高规模下一次性全量扫描导致长时间阻塞。 + s.responseToAccountMu.Lock() + cleanupExpiredAccountBindings(s.responseToAccount, now, openAIWSStateStoreCleanupMaxPerMap) + s.responseToAccountMu.Unlock() + + s.responseToConnMu.Lock() + cleanupExpiredConnBindings(s.responseToConn, now, openAIWSStateStoreCleanupMaxPerMap) + s.responseToConnMu.Unlock() + + s.sessionToTurnStateMu.Lock() + cleanupExpiredTurnStateBindings(s.sessionToTurnState, now, openAIWSStateStoreCleanupMaxPerMap) + s.sessionToTurnStateMu.Unlock() + + s.sessionToConnMu.Lock() + cleanupExpiredSessionConnBindings(s.sessionToConn, now, openAIWSStateStoreCleanupMaxPerMap) + s.sessionToConnMu.Unlock() +} + +func cleanupExpiredAccountBindings(bindings map[string]openAIWSAccountBinding, now time.Time, maxScan int) { + if len(bindings) == 0 || maxScan <= 0 { + return + } + scanned := 0 + for key, binding := range bindings { + if now.After(binding.expiresAt) { + 
delete(bindings, key) + } + scanned++ + if scanned >= maxScan { + break + } + } +} + +func cleanupExpiredConnBindings(bindings map[string]openAIWSConnBinding, now time.Time, maxScan int) { + if len(bindings) == 0 || maxScan <= 0 { + return + } + scanned := 0 + for key, binding := range bindings { + if now.After(binding.expiresAt) { + delete(bindings, key) + } + scanned++ + if scanned >= maxScan { + break + } + } +} + +func cleanupExpiredTurnStateBindings(bindings map[string]openAIWSTurnStateBinding, now time.Time, maxScan int) { + if len(bindings) == 0 || maxScan <= 0 { + return + } + scanned := 0 + for key, binding := range bindings { + if now.After(binding.expiresAt) { + delete(bindings, key) + } + scanned++ + if scanned >= maxScan { + break + } + } +} + +func cleanupExpiredSessionConnBindings(bindings map[string]openAIWSSessionConnBinding, now time.Time, maxScan int) { + if len(bindings) == 0 || maxScan <= 0 { + return + } + scanned := 0 + for key, binding := range bindings { + if now.After(binding.expiresAt) { + delete(bindings, key) + } + scanned++ + if scanned >= maxScan { + break + } + } +} + +func ensureBindingCapacity[T any](bindings map[string]T, incomingKey string, maxEntries int) { + if len(bindings) < maxEntries || maxEntries <= 0 { + return + } + if _, exists := bindings[incomingKey]; exists { + return + } + // 固定上限保护:淘汰任意一项,优先保证内存有界。 + for key := range bindings { + delete(bindings, key) + return + } +} + +func normalizeOpenAIWSResponseID(responseID string) string { + return strings.TrimSpace(responseID) +} + +func openAIWSResponseAccountCacheKey(responseID string) string { + sum := sha256.Sum256([]byte(responseID)) + return openAIWSResponseAccountCachePrefix + hex.EncodeToString(sum[:]) +} + +func normalizeOpenAIWSTTL(ttl time.Duration) time.Duration { + if ttl <= 0 { + return time.Hour + } + return ttl +} + +func openAIWSSessionTurnStateKey(groupID int64, sessionHash string) string { + hash := strings.TrimSpace(sessionHash) + if hash == "" { + 
return "" + } + return fmt.Sprintf("%d:%s", groupID, hash) +} + +func withOpenAIWSStateStoreRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) { + if ctx == nil { + ctx = context.Background() + } + return context.WithTimeout(ctx, openAIWSStateStoreRedisTimeout) +} diff --git a/backend/internal/service/openai_ws_state_store_test.go b/backend/internal/service/openai_ws_state_store_test.go new file mode 100644 index 00000000..235d4233 --- /dev/null +++ b/backend/internal/service/openai_ws_state_store_test.go @@ -0,0 +1,235 @@ +package service + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestOpenAIWSStateStore_BindGetDeleteResponseAccount(t *testing.T) { + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + ctx := context.Background() + groupID := int64(7) + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_abc", 101, time.Minute)) + + accountID, err := store.GetResponseAccount(ctx, groupID, "resp_abc") + require.NoError(t, err) + require.Equal(t, int64(101), accountID) + + require.NoError(t, store.DeleteResponseAccount(ctx, groupID, "resp_abc")) + accountID, err = store.GetResponseAccount(ctx, groupID, "resp_abc") + require.NoError(t, err) + require.Zero(t, accountID) +} + +func TestOpenAIWSStateStore_ResponseConnTTL(t *testing.T) { + store := NewOpenAIWSStateStore(nil) + store.BindResponseConn("resp_conn", "conn_1", 30*time.Millisecond) + + connID, ok := store.GetResponseConn("resp_conn") + require.True(t, ok) + require.Equal(t, "conn_1", connID) + + time.Sleep(60 * time.Millisecond) + _, ok = store.GetResponseConn("resp_conn") + require.False(t, ok) +} + +func TestOpenAIWSStateStore_SessionTurnStateTTL(t *testing.T) { + store := NewOpenAIWSStateStore(nil) + store.BindSessionTurnState(9, "session_hash_1", "turn_state_1", 30*time.Millisecond) + + state, ok := store.GetSessionTurnState(9, "session_hash_1") + require.True(t, ok) + 
require.Equal(t, "turn_state_1", state) + + // group 隔离 + _, ok = store.GetSessionTurnState(10, "session_hash_1") + require.False(t, ok) + + time.Sleep(60 * time.Millisecond) + _, ok = store.GetSessionTurnState(9, "session_hash_1") + require.False(t, ok) +} + +func TestOpenAIWSStateStore_SessionConnTTL(t *testing.T) { + store := NewOpenAIWSStateStore(nil) + store.BindSessionConn(9, "session_hash_conn_1", "conn_1", 30*time.Millisecond) + + connID, ok := store.GetSessionConn(9, "session_hash_conn_1") + require.True(t, ok) + require.Equal(t, "conn_1", connID) + + // group 隔离 + _, ok = store.GetSessionConn(10, "session_hash_conn_1") + require.False(t, ok) + + time.Sleep(60 * time.Millisecond) + _, ok = store.GetSessionConn(9, "session_hash_conn_1") + require.False(t, ok) +} + +func TestOpenAIWSStateStore_GetResponseAccount_NoStaleAfterCacheMiss(t *testing.T) { + cache := &stubGatewayCache{sessionBindings: map[string]int64{}} + store := NewOpenAIWSStateStore(cache) + ctx := context.Background() + groupID := int64(17) + responseID := "resp_cache_stale" + cacheKey := openAIWSResponseAccountCacheKey(responseID) + + cache.sessionBindings[cacheKey] = 501 + accountID, err := store.GetResponseAccount(ctx, groupID, responseID) + require.NoError(t, err) + require.Equal(t, int64(501), accountID) + + delete(cache.sessionBindings, cacheKey) + accountID, err = store.GetResponseAccount(ctx, groupID, responseID) + require.NoError(t, err) + require.Zero(t, accountID, "上游缓存失效后不应继续命中本地陈旧映射") +} + +func TestOpenAIWSStateStore_MaybeCleanupRemovesExpiredIncrementally(t *testing.T) { + raw := NewOpenAIWSStateStore(nil) + store, ok := raw.(*defaultOpenAIWSStateStore) + require.True(t, ok) + + expiredAt := time.Now().Add(-time.Minute) + total := 2048 + store.responseToConnMu.Lock() + for i := 0; i < total; i++ { + store.responseToConn[fmt.Sprintf("resp_%d", i)] = openAIWSConnBinding{ + connID: "conn_incremental", + expiresAt: expiredAt, + } + } + store.responseToConnMu.Unlock() + + 
store.lastCleanupUnixNano.Store(time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano()) + store.maybeCleanup() + + store.responseToConnMu.RLock() + remainingAfterFirst := len(store.responseToConn) + store.responseToConnMu.RUnlock() + require.Less(t, remainingAfterFirst, total, "单轮 cleanup 应至少有进展") + require.Greater(t, remainingAfterFirst, 0, "增量清理不要求单轮清空全部键") + + for i := 0; i < 8; i++ { + store.lastCleanupUnixNano.Store(time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano()) + store.maybeCleanup() + } + + store.responseToConnMu.RLock() + remaining := len(store.responseToConn) + store.responseToConnMu.RUnlock() + require.Zero(t, remaining, "多轮 cleanup 后应逐步清空全部过期键") +} + +func TestEnsureBindingCapacity_EvictsOneWhenMapIsFull(t *testing.T) { + bindings := map[string]int{ + "a": 1, + "b": 2, + } + + ensureBindingCapacity(bindings, "c", 2) + bindings["c"] = 3 + + require.Len(t, bindings, 2) + require.Equal(t, 3, bindings["c"]) +} + +func TestEnsureBindingCapacity_DoesNotEvictWhenUpdatingExistingKey(t *testing.T) { + bindings := map[string]int{ + "a": 1, + "b": 2, + } + + ensureBindingCapacity(bindings, "a", 2) + bindings["a"] = 9 + + require.Len(t, bindings, 2) + require.Equal(t, 9, bindings["a"]) +} + +type openAIWSStateStoreTimeoutProbeCache struct { + setHasDeadline bool + getHasDeadline bool + deleteHasDeadline bool + setDeadlineDelta time.Duration + getDeadlineDelta time.Duration + delDeadlineDelta time.Duration +} + +func (c *openAIWSStateStoreTimeoutProbeCache) GetSessionAccountID(ctx context.Context, _ int64, _ string) (int64, error) { + if deadline, ok := ctx.Deadline(); ok { + c.getHasDeadline = true + c.getDeadlineDelta = time.Until(deadline) + } + return 123, nil +} + +func (c *openAIWSStateStoreTimeoutProbeCache) SetSessionAccountID(ctx context.Context, _ int64, _ string, _ int64, _ time.Duration) error { + if deadline, ok := ctx.Deadline(); ok { + c.setHasDeadline = true + c.setDeadlineDelta = time.Until(deadline) + } + return 
errors.New("set failed") +} + +func (c *openAIWSStateStoreTimeoutProbeCache) RefreshSessionTTL(context.Context, int64, string, time.Duration) error { + return nil +} + +func (c *openAIWSStateStoreTimeoutProbeCache) DeleteSessionAccountID(ctx context.Context, _ int64, _ string) error { + if deadline, ok := ctx.Deadline(); ok { + c.deleteHasDeadline = true + c.delDeadlineDelta = time.Until(deadline) + } + return nil +} + +func TestOpenAIWSStateStore_RedisOpsUseShortTimeout(t *testing.T) { + probe := &openAIWSStateStoreTimeoutProbeCache{} + store := NewOpenAIWSStateStore(probe) + ctx := context.Background() + groupID := int64(5) + + err := store.BindResponseAccount(ctx, groupID, "resp_timeout_probe", 11, time.Minute) + require.Error(t, err) + + accountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_timeout_probe") + require.NoError(t, getErr) + require.Equal(t, int64(11), accountID, "本地缓存命中应优先返回已绑定账号") + + require.NoError(t, store.DeleteResponseAccount(ctx, groupID, "resp_timeout_probe")) + + require.True(t, probe.setHasDeadline, "SetSessionAccountID 应携带独立超时上下文") + require.True(t, probe.deleteHasDeadline, "DeleteSessionAccountID 应携带独立超时上下文") + require.False(t, probe.getHasDeadline, "GetSessionAccountID 本用例应由本地缓存命中,不触发 Redis 读取") + require.Greater(t, probe.setDeadlineDelta, 2*time.Second) + require.LessOrEqual(t, probe.setDeadlineDelta, 3*time.Second) + require.Greater(t, probe.delDeadlineDelta, 2*time.Second) + require.LessOrEqual(t, probe.delDeadlineDelta, 3*time.Second) + + probe2 := &openAIWSStateStoreTimeoutProbeCache{} + store2 := NewOpenAIWSStateStore(probe2) + accountID2, err2 := store2.GetResponseAccount(ctx, groupID, "resp_cache_only") + require.NoError(t, err2) + require.Equal(t, int64(123), accountID2) + require.True(t, probe2.getHasDeadline, "GetSessionAccountID 在缓存未命中时应携带独立超时上下文") + require.Greater(t, probe2.getDeadlineDelta, 2*time.Second) + require.LessOrEqual(t, probe2.getDeadlineDelta, 3*time.Second) +} + +func 
TestWithOpenAIWSStateStoreRedisTimeout_WithParentContext(t *testing.T) { + ctx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background()) + defer cancel() + require.NotNil(t, ctx) + _, ok := ctx.Deadline() + require.True(t, ok, "应附加短超时") +} diff --git a/backend/internal/service/ops_aggregation_service.go b/backend/internal/service/ops_aggregation_service.go index 972462ec..ec77fe12 100644 --- a/backend/internal/service/ops_aggregation_service.go +++ b/backend/internal/service/ops_aggregation_service.go @@ -5,12 +5,12 @@ import ( "database/sql" "errors" "fmt" - "log" "strings" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/google/uuid" "github.com/redis/go-redis/v9" ) @@ -190,7 +190,7 @@ func (s *OpsAggregationService) aggregateHourly() { latest, ok, err := s.opsRepo.GetLatestHourlyBucketStart(ctxMax) cancelMax() if err != nil { - log.Printf("[OpsAggregation][hourly] failed to read latest bucket: %v", err) + logger.LegacyPrintf("service.ops_aggregation", "[OpsAggregation][hourly] failed to read latest bucket: %v", err) } else if ok { candidate := latest.Add(-opsAggHourlyOverlap) if candidate.After(start) { @@ -209,7 +209,7 @@ func (s *OpsAggregationService) aggregateHourly() { chunkEnd := minTime(cursor.Add(opsAggHourlyChunk), end) if err := s.opsRepo.UpsertHourlyMetrics(ctx, cursor, chunkEnd); err != nil { aggErr = err - log.Printf("[OpsAggregation][hourly] upsert failed (%s..%s): %v", cursor.Format(time.RFC3339), chunkEnd.Format(time.RFC3339), err) + logger.LegacyPrintf("service.ops_aggregation", "[OpsAggregation][hourly] upsert failed (%s..%s): %v", cursor.Format(time.RFC3339), chunkEnd.Format(time.RFC3339), err) break } } @@ -288,7 +288,7 @@ func (s *OpsAggregationService) aggregateDaily() { latest, ok, err := s.opsRepo.GetLatestDailyBucketDate(ctxMax) cancelMax() if err != nil { - log.Printf("[OpsAggregation][daily] failed to read latest bucket: %v", err) + 
logger.LegacyPrintf("service.ops_aggregation", "[OpsAggregation][daily] failed to read latest bucket: %v", err) } else if ok { candidate := latest.Add(-opsAggDailyOverlap) if candidate.After(start) { @@ -307,7 +307,7 @@ func (s *OpsAggregationService) aggregateDaily() { chunkEnd := minTime(cursor.Add(opsAggDailyChunk), end) if err := s.opsRepo.UpsertDailyMetrics(ctx, cursor, chunkEnd); err != nil { aggErr = err - log.Printf("[OpsAggregation][daily] upsert failed (%s..%s): %v", cursor.Format("2006-01-02"), chunkEnd.Format("2006-01-02"), err) + logger.LegacyPrintf("service.ops_aggregation", "[OpsAggregation][daily] upsert failed (%s..%s): %v", cursor.Format("2006-01-02"), chunkEnd.Format("2006-01-02"), err) break } } @@ -427,7 +427,7 @@ func (s *OpsAggregationService) maybeLogSkip(prefix string) { if prefix == "" { prefix = "[OpsAggregation]" } - log.Printf("%s leader lock held by another instance; skipping", prefix) + logger.LegacyPrintf("service.ops_aggregation", "%s leader lock held by another instance; skipping", prefix) } func utcFloorToHour(t time.Time) time.Time { diff --git a/backend/internal/service/ops_alert_evaluator_service.go b/backend/internal/service/ops_alert_evaluator_service.go index 7c62e247..169a5e32 100644 --- a/backend/internal/service/ops_alert_evaluator_service.go +++ b/backend/internal/service/ops_alert_evaluator_service.go @@ -3,7 +3,6 @@ package service import ( "context" "fmt" - "log" "math" "strconv" "strings" @@ -11,6 +10,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/google/uuid" "github.com/redis/go-redis/v9" ) @@ -186,7 +186,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) { rules, err := s.opsRepo.ListAlertRules(ctx) if err != nil { s.recordHeartbeatError(runAt, time.Since(startedAt), err) - log.Printf("[OpsAlertEvaluator] list rules failed: %v", err) + logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] 
list rules failed: %v", err) return } @@ -236,7 +236,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) { activeEvent, err := s.opsRepo.GetActiveAlertEvent(ctx, rule.ID) if err != nil { - log.Printf("[OpsAlertEvaluator] get active event failed (rule=%d): %v", rule.ID, err) + logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] get active event failed (rule=%d): %v", rule.ID, err) continue } @@ -258,7 +258,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) { latestEvent, err := s.opsRepo.GetLatestAlertEvent(ctx, rule.ID) if err != nil { - log.Printf("[OpsAlertEvaluator] get latest event failed (rule=%d): %v", rule.ID, err) + logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] get latest event failed (rule=%d): %v", rule.ID, err) continue } if latestEvent != nil && rule.CooldownMinutes > 0 { @@ -283,7 +283,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) { created, err := s.opsRepo.CreateAlertEvent(ctx, firedEvent) if err != nil { - log.Printf("[OpsAlertEvaluator] create event failed (rule=%d): %v", rule.ID, err) + logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] create event failed (rule=%d): %v", rule.ID, err) continue } @@ -300,7 +300,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) { if activeEvent != nil { resolvedAt := now if err := s.opsRepo.UpdateAlertEventStatus(ctx, activeEvent.ID, OpsAlertStatusResolved, &resolvedAt); err != nil { - log.Printf("[OpsAlertEvaluator] resolve event failed (event=%d): %v", activeEvent.ID, err) + logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] resolve event failed (event=%d): %v", activeEvent.ID, err) } else { eventsResolved++ } @@ -779,7 +779,7 @@ func (s *OpsAlertEvaluatorService) tryAcquireLeaderLock(ctx context.Context, loc } if s.redisClient == nil { s.warnNoRedisOnce.Do(func() { - log.Printf("[OpsAlertEvaluator] redis not 
configured; running without distributed lock") + logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] redis not configured; running without distributed lock") }) return nil, true } @@ -797,7 +797,7 @@ func (s *OpsAlertEvaluatorService) tryAcquireLeaderLock(ctx context.Context, loc // Prefer fail-closed to avoid duplicate evaluators stampeding the DB when Redis is flaky. // Single-node deployments can disable the distributed lock via runtime settings. s.warnNoRedisOnce.Do(func() { - log.Printf("[OpsAlertEvaluator] leader lock SetNX failed; skipping this cycle: %v", err) + logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] leader lock SetNX failed; skipping this cycle: %v", err) }) return nil, false } @@ -819,7 +819,7 @@ func (s *OpsAlertEvaluatorService) maybeLogSkip(key string) { return } s.skipLogAt = now - log.Printf("[OpsAlertEvaluator] leader lock held by another instance; skipping (key=%q)", key) + logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] leader lock held by another instance; skipping (key=%q)", key) } func (s *OpsAlertEvaluatorService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration, result string) { diff --git a/backend/internal/service/ops_alert_evaluator_service_test.go b/backend/internal/service/ops_alert_evaluator_service_test.go index 068ab6bb..83d358a3 100644 --- a/backend/internal/service/ops_alert_evaluator_service_test.go +++ b/backend/internal/service/ops_alert_evaluator_service_test.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/require" ) +var _ OpsRepository = (*stubOpsRepo)(nil) + type stubOpsRepo struct { OpsRepository overview *OpsDashboardOverview diff --git a/backend/internal/service/ops_cleanup_service.go b/backend/internal/service/ops_cleanup_service.go index 1ade7176..1cae6fe5 100644 --- a/backend/internal/service/ops_cleanup_service.go +++ b/backend/internal/service/ops_cleanup_service.go @@ -4,12 +4,12 @@ import ( "context" "database/sql" 
"fmt" - "log" "strings" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/google/uuid" "github.com/redis/go-redis/v9" "github.com/robfig/cron/v3" @@ -75,11 +75,11 @@ func (s *OpsCleanupService) Start() { return } if s.cfg != nil && !s.cfg.Ops.Cleanup.Enabled { - log.Printf("[OpsCleanup] not started (disabled)") + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] not started (disabled)") return } if s.opsRepo == nil || s.db == nil { - log.Printf("[OpsCleanup] not started (missing deps)") + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] not started (missing deps)") return } @@ -99,12 +99,12 @@ func (s *OpsCleanupService) Start() { c := cron.New(cron.WithParser(opsCleanupCronParser), cron.WithLocation(loc)) _, err := c.AddFunc(schedule, func() { s.runScheduled() }) if err != nil { - log.Printf("[OpsCleanup] not started (invalid schedule=%q): %v", schedule, err) + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] not started (invalid schedule=%q): %v", schedule, err) return } s.cron = c s.cron.Start() - log.Printf("[OpsCleanup] started (schedule=%q tz=%s)", schedule, loc.String()) + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] started (schedule=%q tz=%s)", schedule, loc.String()) }) } @@ -118,7 +118,7 @@ func (s *OpsCleanupService) Stop() { select { case <-ctx.Done(): case <-time.After(3 * time.Second): - log.Printf("[OpsCleanup] cron stop timed out") + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] cron stop timed out") } } }) @@ -146,17 +146,19 @@ func (s *OpsCleanupService) runScheduled() { counts, err := s.runCleanupOnce(ctx) if err != nil { s.recordHeartbeatError(runAt, time.Since(startedAt), err) - log.Printf("[OpsCleanup] cleanup failed: %v", err) + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] cleanup failed: %v", err) return } s.recordHeartbeatSuccess(runAt, time.Since(startedAt), counts) - log.Printf("[OpsCleanup] cleanup 
complete: %s", counts) + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] cleanup complete: %s", counts) } type opsCleanupDeletedCounts struct { errorLogs int64 retryAttempts int64 alertEvents int64 + systemLogs int64 + logAudits int64 systemMetrics int64 hourlyPreagg int64 dailyPreagg int64 @@ -164,10 +166,12 @@ type opsCleanupDeletedCounts struct { func (c opsCleanupDeletedCounts) String() string { return fmt.Sprintf( - "error_logs=%d retry_attempts=%d alert_events=%d system_metrics=%d hourly_preagg=%d daily_preagg=%d", + "error_logs=%d retry_attempts=%d alert_events=%d system_logs=%d log_audits=%d system_metrics=%d hourly_preagg=%d daily_preagg=%d", c.errorLogs, c.retryAttempts, c.alertEvents, + c.systemLogs, + c.logAudits, c.systemMetrics, c.hourlyPreagg, c.dailyPreagg, @@ -204,6 +208,18 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet return out, err } out.alertEvents = n + + n, err = deleteOldRowsByID(ctx, s.db, "ops_system_logs", "created_at", cutoff, batchSize, false) + if err != nil { + return out, err + } + out.systemLogs = n + + n, err = deleteOldRowsByID(ctx, s.db, "ops_system_log_cleanup_audits", "created_at", cutoff, batchSize, false) + if err != nil { + return out, err + } + out.logAudits = n } // Minute-level metrics snapshots. @@ -315,11 +331,11 @@ func (s *OpsCleanupService) tryAcquireLeaderLock(ctx context.Context) (func(), b } // Redis error: fall back to DB advisory lock. 
s.warnNoRedisOnce.Do(func() { - log.Printf("[OpsCleanup] leader lock SetNX failed; falling back to DB advisory lock: %v", err) + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] leader lock SetNX failed; falling back to DB advisory lock: %v", err) }) } else { s.warnNoRedisOnce.Do(func() { - log.Printf("[OpsCleanup] redis not configured; using DB advisory lock") + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] redis not configured; using DB advisory lock") }) } diff --git a/backend/internal/service/ops_dashboard.go b/backend/internal/service/ops_dashboard.go index 31822ba8..6f70c75c 100644 --- a/backend/internal/service/ops_dashboard.go +++ b/backend/internal/service/ops_dashboard.go @@ -31,6 +31,10 @@ func (s *OpsService) GetDashboardOverview(ctx context.Context, filter *OpsDashbo filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode) overview, err := s.opsRepo.GetDashboardOverview(ctx, filter) + if err != nil && shouldFallbackOpsPreagg(filter, err) { + rawFilter := cloneOpsFilterWithMode(filter, OpsQueryModeRaw) + overview, err = s.opsRepo.GetDashboardOverview(ctx, rawFilter) + } if err != nil { if errors.Is(err, ErrOpsPreaggregatedNotPopulated) { return nil, infraerrors.Conflict("OPS_PREAGG_NOT_READY", "Pre-aggregated ops metrics are not populated yet") diff --git a/backend/internal/service/ops_errors.go b/backend/internal/service/ops_errors.go index 76b5ce8b..01671c1e 100644 --- a/backend/internal/service/ops_errors.go +++ b/backend/internal/service/ops_errors.go @@ -22,7 +22,14 @@ func (s *OpsService) GetErrorTrend(ctx context.Context, filter *OpsDashboardFilt if filter.StartTime.After(filter.EndTime) { return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time") } - return s.opsRepo.GetErrorTrend(ctx, filter, bucketSeconds) + filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode) + + result, err := s.opsRepo.GetErrorTrend(ctx, filter, bucketSeconds) + if err != nil && 
shouldFallbackOpsPreagg(filter, err) { + rawFilter := cloneOpsFilterWithMode(filter, OpsQueryModeRaw) + return s.opsRepo.GetErrorTrend(ctx, rawFilter, bucketSeconds) + } + return result, err } func (s *OpsService) GetErrorDistribution(ctx context.Context, filter *OpsDashboardFilter) (*OpsErrorDistributionResponse, error) { @@ -41,5 +48,12 @@ func (s *OpsService) GetErrorDistribution(ctx context.Context, filter *OpsDashbo if filter.StartTime.After(filter.EndTime) { return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time") } - return s.opsRepo.GetErrorDistribution(ctx, filter) + filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode) + + result, err := s.opsRepo.GetErrorDistribution(ctx, filter) + if err != nil && shouldFallbackOpsPreagg(filter, err) { + rawFilter := cloneOpsFilterWithMode(filter, OpsQueryModeRaw) + return s.opsRepo.GetErrorDistribution(ctx, rawFilter) + } + return result, err } diff --git a/backend/internal/service/ops_histograms.go b/backend/internal/service/ops_histograms.go index 9f5b514f..c555dbfc 100644 --- a/backend/internal/service/ops_histograms.go +++ b/backend/internal/service/ops_histograms.go @@ -22,5 +22,12 @@ func (s *OpsService) GetLatencyHistogram(ctx context.Context, filter *OpsDashboa if filter.StartTime.After(filter.EndTime) { return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time") } - return s.opsRepo.GetLatencyHistogram(ctx, filter) + filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode) + + result, err := s.opsRepo.GetLatencyHistogram(ctx, filter) + if err != nil && shouldFallbackOpsPreagg(filter, err) { + rawFilter := cloneOpsFilterWithMode(filter, OpsQueryModeRaw) + return s.opsRepo.GetLatencyHistogram(ctx, rawFilter) + } + return result, err } diff --git a/backend/internal/service/ops_log_runtime.go b/backend/internal/service/ops_log_runtime.go new file mode 100644 index 00000000..ed8aefa9 --- /dev/null +++ 
b/backend/internal/service/ops_log_runtime.go @@ -0,0 +1,267 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "go.uber.org/zap" +) + +func defaultOpsRuntimeLogConfig(cfg *config.Config) *OpsRuntimeLogConfig { + out := &OpsRuntimeLogConfig{ + Level: "info", + EnableSampling: false, + SamplingInitial: 100, + SamplingNext: 100, + Caller: true, + StacktraceLevel: "error", + RetentionDays: 30, + } + if cfg == nil { + return out + } + out.Level = strings.ToLower(strings.TrimSpace(cfg.Log.Level)) + out.EnableSampling = cfg.Log.Sampling.Enabled + out.SamplingInitial = cfg.Log.Sampling.Initial + out.SamplingNext = cfg.Log.Sampling.Thereafter + out.Caller = cfg.Log.Caller + out.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel)) + if cfg.Ops.Cleanup.ErrorLogRetentionDays > 0 { + out.RetentionDays = cfg.Ops.Cleanup.ErrorLogRetentionDays + } + return out +} + +func normalizeOpsRuntimeLogConfig(cfg *OpsRuntimeLogConfig, defaults *OpsRuntimeLogConfig) { + if cfg == nil || defaults == nil { + return + } + cfg.Level = strings.ToLower(strings.TrimSpace(cfg.Level)) + if cfg.Level == "" { + cfg.Level = defaults.Level + } + cfg.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.StacktraceLevel)) + if cfg.StacktraceLevel == "" { + cfg.StacktraceLevel = defaults.StacktraceLevel + } + if cfg.SamplingInitial <= 0 { + cfg.SamplingInitial = defaults.SamplingInitial + } + if cfg.SamplingNext <= 0 { + cfg.SamplingNext = defaults.SamplingNext + } + if cfg.RetentionDays <= 0 { + cfg.RetentionDays = defaults.RetentionDays + } +} + +func validateOpsRuntimeLogConfig(cfg *OpsRuntimeLogConfig) error { + if cfg == nil { + return errors.New("invalid config") + } + switch strings.ToLower(strings.TrimSpace(cfg.Level)) { + case "debug", "info", "warn", "error": + default: + return errors.New("level must be one of: 
debug/info/warn/error") + } + switch strings.ToLower(strings.TrimSpace(cfg.StacktraceLevel)) { + case "none", "error", "fatal": + default: + return errors.New("stacktrace_level must be one of: none/error/fatal") + } + if cfg.SamplingInitial <= 0 { + return errors.New("sampling_initial must be positive") + } + if cfg.SamplingNext <= 0 { + return errors.New("sampling_thereafter must be positive") + } + if cfg.RetentionDays < 1 || cfg.RetentionDays > 3650 { + return errors.New("retention_days must be between 1 and 3650") + } + return nil +} + +func (s *OpsService) GetRuntimeLogConfig(ctx context.Context) (*OpsRuntimeLogConfig, error) { + if s == nil || s.settingRepo == nil { + var cfg *config.Config + if s != nil { + cfg = s.cfg + } + defaultCfg := defaultOpsRuntimeLogConfig(cfg) + return defaultCfg, nil + } + defaultCfg := defaultOpsRuntimeLogConfig(s.cfg) + if ctx == nil { + ctx = context.Background() + } + + raw, err := s.settingRepo.GetValue(ctx, SettingKeyOpsRuntimeLogConfig) + if err != nil { + if errors.Is(err, ErrSettingNotFound) { + b, _ := json.Marshal(defaultCfg) + _ = s.settingRepo.Set(ctx, SettingKeyOpsRuntimeLogConfig, string(b)) + return defaultCfg, nil + } + return nil, err + } + + cfg := &OpsRuntimeLogConfig{} + if err := json.Unmarshal([]byte(raw), cfg); err != nil { + return defaultCfg, nil + } + normalizeOpsRuntimeLogConfig(cfg, defaultCfg) + return cfg, nil +} + +func (s *OpsService) UpdateRuntimeLogConfig(ctx context.Context, req *OpsRuntimeLogConfig, operatorID int64) (*OpsRuntimeLogConfig, error) { + if s == nil || s.settingRepo == nil { + return nil, errors.New("setting repository not initialized") + } + if req == nil { + return nil, errors.New("invalid config") + } + if ctx == nil { + ctx = context.Background() + } + if operatorID <= 0 { + return nil, errors.New("invalid operator id") + } + + oldCfg, err := s.GetRuntimeLogConfig(ctx) + if err != nil { + return nil, err + } + next := *req + normalizeOpsRuntimeLogConfig(&next, 
defaultOpsRuntimeLogConfig(s.cfg)) + if err := validateOpsRuntimeLogConfig(&next); err != nil { + s.auditRuntimeLogConfigFailure(operatorID, oldCfg, &next, "validation_failed: "+err.Error()) + return nil, err + } + + if err := applyOpsRuntimeLogConfig(&next); err != nil { + s.auditRuntimeLogConfigFailure(operatorID, oldCfg, &next, "apply_failed: "+err.Error()) + return nil, err + } + + next.Source = "runtime_setting" + next.UpdatedAt = time.Now().UTC().Format(time.RFC3339Nano) + next.UpdatedByUserID = operatorID + + encoded, err := json.Marshal(&next) + if err != nil { + return nil, err + } + if err := s.settingRepo.Set(ctx, SettingKeyOpsRuntimeLogConfig, string(encoded)); err != nil { + // 存储失败时回滚到旧配置,避免内存状态与持久化状态不一致。 + _ = applyOpsRuntimeLogConfig(oldCfg) + s.auditRuntimeLogConfigFailure(operatorID, oldCfg, &next, "persist_failed: "+err.Error()) + return nil, err + } + + s.auditRuntimeLogConfigChange(operatorID, oldCfg, &next, "updated") + + return &next, nil +} + +func (s *OpsService) ResetRuntimeLogConfig(ctx context.Context, operatorID int64) (*OpsRuntimeLogConfig, error) { + if s == nil || s.settingRepo == nil { + return nil, errors.New("setting repository not initialized") + } + if ctx == nil { + ctx = context.Background() + } + if operatorID <= 0 { + return nil, errors.New("invalid operator id") + } + + oldCfg, err := s.GetRuntimeLogConfig(ctx) + if err != nil { + return nil, err + } + + resetCfg := defaultOpsRuntimeLogConfig(s.cfg) + normalizeOpsRuntimeLogConfig(resetCfg, defaultOpsRuntimeLogConfig(s.cfg)) + if err := validateOpsRuntimeLogConfig(resetCfg); err != nil { + s.auditRuntimeLogConfigFailure(operatorID, oldCfg, resetCfg, "reset_validation_failed: "+err.Error()) + return nil, err + } + if err := applyOpsRuntimeLogConfig(resetCfg); err != nil { + s.auditRuntimeLogConfigFailure(operatorID, oldCfg, resetCfg, "reset_apply_failed: "+err.Error()) + return nil, err + } + + // 清理 runtime 覆盖配置,回退到 env/yaml baseline。 + if err := s.settingRepo.Delete(ctx, 
SettingKeyOpsRuntimeLogConfig); err != nil && !errors.Is(err, ErrSettingNotFound) { + _ = applyOpsRuntimeLogConfig(oldCfg) + s.auditRuntimeLogConfigFailure(operatorID, oldCfg, resetCfg, "reset_persist_failed: "+err.Error()) + return nil, err + } + + now := time.Now().UTC().Format(time.RFC3339Nano) + resetCfg.Source = "baseline" + resetCfg.UpdatedAt = now + resetCfg.UpdatedByUserID = operatorID + + s.auditRuntimeLogConfigChange(operatorID, oldCfg, resetCfg, "reset") + return resetCfg, nil +} + +func applyOpsRuntimeLogConfig(cfg *OpsRuntimeLogConfig) error { + if cfg == nil { + return fmt.Errorf("nil runtime log config") + } + if err := logger.Reconfigure(func(opts *logger.InitOptions) error { + opts.Level = strings.ToLower(strings.TrimSpace(cfg.Level)) + opts.Caller = cfg.Caller + opts.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.StacktraceLevel)) + opts.Sampling.Enabled = cfg.EnableSampling + opts.Sampling.Initial = cfg.SamplingInitial + opts.Sampling.Thereafter = cfg.SamplingNext + return nil + }); err != nil { + return err + } + return nil +} + +func (s *OpsService) applyRuntimeLogConfigOnStartup(ctx context.Context) { + if s == nil { + return + } + cfg, err := s.GetRuntimeLogConfig(ctx) + if err != nil { + return + } + _ = applyOpsRuntimeLogConfig(cfg) +} + +func (s *OpsService) auditRuntimeLogConfigChange(operatorID int64, oldCfg *OpsRuntimeLogConfig, newCfg *OpsRuntimeLogConfig, action string) { + oldRaw, _ := json.Marshal(oldCfg) + newRaw, _ := json.Marshal(newCfg) + logger.With( + zap.String("component", "audit.log_config_change"), + zap.String("action", strings.TrimSpace(action)), + zap.Int64("operator_id", operatorID), + zap.String("old", string(oldRaw)), + zap.String("new", string(newRaw)), + ).Info("runtime log config changed") +} + +func (s *OpsService) auditRuntimeLogConfigFailure(operatorID int64, oldCfg *OpsRuntimeLogConfig, newCfg *OpsRuntimeLogConfig, reason string) { + oldRaw, _ := json.Marshal(oldCfg) + newRaw, _ := 
json.Marshal(newCfg) + logger.With( + zap.String("component", "audit.log_config_change"), + zap.String("action", "failed"), + zap.Int64("operator_id", operatorID), + zap.String("reason", strings.TrimSpace(reason)), + zap.String("old", string(oldRaw)), + zap.String("new", string(newRaw)), + ).Warn("runtime log config change failed") +} diff --git a/backend/internal/service/ops_log_runtime_test.go b/backend/internal/service/ops_log_runtime_test.go new file mode 100644 index 00000000..658b4812 --- /dev/null +++ b/backend/internal/service/ops_log_runtime_test.go @@ -0,0 +1,570 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +type runtimeSettingRepoStub struct { + values map[string]string + deleted map[string]bool + setCalls int + getValueFn func(key string) (string, error) + setFn func(key, value string) error + deleteFn func(key string) error +} + +func newRuntimeSettingRepoStub() *runtimeSettingRepoStub { + return &runtimeSettingRepoStub{ + values: map[string]string{}, + deleted: map[string]bool{}, + } +} + +func (s *runtimeSettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + value, err := s.GetValue(ctx, key) + if err != nil { + return nil, err + } + return &Setting{Key: key, Value: value}, nil +} + +func (s *runtimeSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + if s.getValueFn != nil { + return s.getValueFn(key) + } + value, ok := s.values[key] + if !ok { + return "", ErrSettingNotFound + } + return value, nil +} + +func (s *runtimeSettingRepoStub) Set(_ context.Context, key, value string) error { + if s.setFn != nil { + if err := s.setFn(key, value); err != nil { + return err + } + } + s.values[key] = value + s.setCalls++ + return nil +} + +func (s *runtimeSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + out := 
make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *runtimeSettingRepoStub) SetMultiple(_ context.Context, settings map[string]string) error { + for key, value := range settings { + s.values[key] = value + } + return nil +} + +func (s *runtimeSettingRepoStub) GetAll(_ context.Context) (map[string]string, error) { + out := make(map[string]string, len(s.values)) + for key, value := range s.values { + out[key] = value + } + return out, nil +} + +func (s *runtimeSettingRepoStub) Delete(_ context.Context, key string) error { + if s.deleteFn != nil { + if err := s.deleteFn(key); err != nil { + return err + } + } + if _, ok := s.values[key]; !ok { + return ErrSettingNotFound + } + delete(s.values, key) + s.deleted[key] = true + return nil +} + +func TestUpdateRuntimeLogConfig_InvalidConfigShouldNotApply(t *testing.T) { + repo := newRuntimeSettingRepoStub() + svc := &OpsService{ + settingRepo: repo, + cfg: &config.Config{ + Log: config.LogConfig{ + Level: "info", + Caller: true, + StacktraceLevel: "error", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + }, + } + + if err := logger.Init(logger.InitOptions{ + Level: "info", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: true, + ToFile: false, + }, + }); err != nil { + t.Fatalf("init logger: %v", err) + } + + _, err := svc.UpdateRuntimeLogConfig(context.Background(), &OpsRuntimeLogConfig{ + Level: "trace", + EnableSampling: true, + SamplingInitial: 100, + SamplingNext: 100, + Caller: true, + StacktraceLevel: "error", + RetentionDays: 30, + }, 1) + if err == nil { + t.Fatalf("expected validation error") + } + if logger.CurrentLevel() != "info" { + t.Fatalf("logger level changed unexpectedly: %s", logger.CurrentLevel()) + } + if repo.setCalls != 1 { + // GetRuntimeLogConfig() 会在 key 
缺失时写入默认值,此处应只有这一次持久化。 + t.Fatalf("unexpected set calls: %d", repo.setCalls) + } +} + +func TestResetRuntimeLogConfig_ShouldFallbackToBaseline(t *testing.T) { + repo := newRuntimeSettingRepoStub() + existing := &OpsRuntimeLogConfig{ + Level: "debug", + EnableSampling: true, + SamplingInitial: 50, + SamplingNext: 50, + Caller: true, + StacktraceLevel: "error", + RetentionDays: 60, + Source: "runtime_setting", + } + raw, _ := json.Marshal(existing) + repo.values[SettingKeyOpsRuntimeLogConfig] = string(raw) + + svc := &OpsService{ + settingRepo: repo, + cfg: &config.Config{ + Log: config.LogConfig{ + Level: "warn", + Caller: false, + StacktraceLevel: "fatal", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + Ops: config.OpsConfig{ + Cleanup: config.OpsCleanupConfig{ + ErrorLogRetentionDays: 45, + }, + }, + }, + } + + if err := logger.Init(logger.InitOptions{ + Level: "debug", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: true, + ToFile: false, + }, + }); err != nil { + t.Fatalf("init logger: %v", err) + } + + resetCfg, err := svc.ResetRuntimeLogConfig(context.Background(), 9) + if err != nil { + t.Fatalf("ResetRuntimeLogConfig() error: %v", err) + } + if resetCfg.Source != "baseline" { + t.Fatalf("source = %q, want baseline", resetCfg.Source) + } + if resetCfg.Level != "warn" { + t.Fatalf("level = %q, want warn", resetCfg.Level) + } + if resetCfg.RetentionDays != 45 { + t.Fatalf("retention_days = %d, want 45", resetCfg.RetentionDays) + } + if logger.CurrentLevel() != "warn" { + t.Fatalf("logger level = %q, want warn", logger.CurrentLevel()) + } + if !repo.deleted[SettingKeyOpsRuntimeLogConfig] { + t.Fatalf("runtime setting key should be deleted") + } +} + +func TestResetRuntimeLogConfig_InvalidOperator(t *testing.T) { + svc := &OpsService{settingRepo: newRuntimeSettingRepoStub()} + _, err := svc.ResetRuntimeLogConfig(context.Background(), 0) + if err == 
nil { + t.Fatalf("expected invalid operator error") + } + if err.Error() != "invalid operator id" { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestGetRuntimeLogConfig_InvalidJSONFallback(t *testing.T) { + repo := newRuntimeSettingRepoStub() + repo.values[SettingKeyOpsRuntimeLogConfig] = `{invalid-json}` + + svc := &OpsService{ + settingRepo: repo, + cfg: &config.Config{ + Log: config.LogConfig{ + Level: "warn", + Caller: true, + StacktraceLevel: "error", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + }, + } + + got, err := svc.GetRuntimeLogConfig(context.Background()) + if err != nil { + t.Fatalf("GetRuntimeLogConfig() error: %v", err) + } + if got.Level != "warn" { + t.Fatalf("level = %q, want warn", got.Level) + } +} + +func TestUpdateRuntimeLogConfig_PersistFailureRollback(t *testing.T) { + repo := newRuntimeSettingRepoStub() + oldCfg := &OpsRuntimeLogConfig{ + Level: "info", + EnableSampling: false, + SamplingInitial: 100, + SamplingNext: 100, + Caller: true, + StacktraceLevel: "error", + RetentionDays: 30, + } + raw, _ := json.Marshal(oldCfg) + repo.values[SettingKeyOpsRuntimeLogConfig] = string(raw) + repo.setFn = func(key, value string) error { + if key == SettingKeyOpsRuntimeLogConfig { + return errors.New("db down") + } + return nil + } + + svc := &OpsService{ + settingRepo: repo, + cfg: &config.Config{ + Log: config.LogConfig{ + Level: "info", + Caller: true, + StacktraceLevel: "error", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + }, + } + + if err := logger.Init(logger.InitOptions{ + Level: "info", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: true, + ToFile: false, + }, + }); err != nil { + t.Fatalf("init logger: %v", err) + } + + _, err := svc.UpdateRuntimeLogConfig(context.Background(), &OpsRuntimeLogConfig{ + Level: "debug", + EnableSampling: false, + 
SamplingInitial: 100, + SamplingNext: 100, + Caller: true, + StacktraceLevel: "error", + RetentionDays: 30, + }, 5) + if err == nil { + t.Fatalf("expected persist error") + } + // Persist failure should rollback runtime level back to old effective level. + if logger.CurrentLevel() != "info" { + t.Fatalf("logger level should rollback to info, got %s", logger.CurrentLevel()) + } +} + +func TestApplyRuntimeLogConfigOnStartup(t *testing.T) { + repo := newRuntimeSettingRepoStub() + cfgRaw := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}` + repo.values[SettingKeyOpsRuntimeLogConfig] = cfgRaw + + svc := &OpsService{ + settingRepo: repo, + cfg: &config.Config{ + Log: config.LogConfig{ + Level: "info", + Caller: true, + StacktraceLevel: "error", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + }, + } + + if err := logger.Init(logger.InitOptions{ + Level: "info", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: true, + ToFile: false, + }, + }); err != nil { + t.Fatalf("init logger: %v", err) + } + + svc.applyRuntimeLogConfigOnStartup(context.Background()) + if logger.CurrentLevel() != "debug" { + t.Fatalf("expected startup apply debug, got %s", logger.CurrentLevel()) + } +} + +func TestDefaultNormalizeAndValidateRuntimeLogConfig(t *testing.T) { + defaults := defaultOpsRuntimeLogConfig(&config.Config{ + Log: config.LogConfig{ + Level: "DEBUG", + Caller: false, + StacktraceLevel: "FATAL", + Sampling: config.LogSamplingConfig{ + Enabled: true, + Initial: 50, + Thereafter: 20, + }, + }, + Ops: config.OpsConfig{ + Cleanup: config.OpsCleanupConfig{ + ErrorLogRetentionDays: 7, + }, + }, + }) + if defaults.Level != "debug" || defaults.StacktraceLevel != "fatal" || defaults.RetentionDays != 7 { + t.Fatalf("unexpected defaults: %+v", defaults) + } + + cfg := 
&OpsRuntimeLogConfig{ + Level: " ", + EnableSampling: true, + SamplingInitial: 0, + SamplingNext: -1, + Caller: true, + StacktraceLevel: "", + RetentionDays: 0, + } + normalizeOpsRuntimeLogConfig(cfg, defaults) + if cfg.Level != "debug" || cfg.StacktraceLevel != "fatal" { + t.Fatalf("normalize level/stacktrace failed: %+v", cfg) + } + if cfg.SamplingInitial != 50 || cfg.SamplingNext != 20 || cfg.RetentionDays != 7 { + t.Fatalf("normalize numeric defaults failed: %+v", cfg) + } + if err := validateOpsRuntimeLogConfig(cfg); err != nil { + t.Fatalf("validate normalized config should pass: %v", err) + } +} + +func TestValidateRuntimeLogConfigErrors(t *testing.T) { + cases := []struct { + name string + cfg *OpsRuntimeLogConfig + }{ + {name: "nil", cfg: nil}, + {name: "bad level", cfg: &OpsRuntimeLogConfig{Level: "trace", StacktraceLevel: "error", SamplingInitial: 1, SamplingNext: 1, RetentionDays: 1}}, + {name: "bad stack", cfg: &OpsRuntimeLogConfig{Level: "info", StacktraceLevel: "warn", SamplingInitial: 1, SamplingNext: 1, RetentionDays: 1}}, + {name: "bad initial", cfg: &OpsRuntimeLogConfig{Level: "info", StacktraceLevel: "error", SamplingInitial: 0, SamplingNext: 1, RetentionDays: 1}}, + {name: "bad next", cfg: &OpsRuntimeLogConfig{Level: "info", StacktraceLevel: "error", SamplingInitial: 1, SamplingNext: 0, RetentionDays: 1}}, + {name: "bad retention", cfg: &OpsRuntimeLogConfig{Level: "info", StacktraceLevel: "error", SamplingInitial: 1, SamplingNext: 1, RetentionDays: 0}}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if err := validateOpsRuntimeLogConfig(tc.cfg); err == nil { + t.Fatalf("expected validation error") + } + }) + } +} + +func TestGetRuntimeLogConfigFallbackAndErrors(t *testing.T) { + var nilSvc *OpsService + cfg, err := nilSvc.GetRuntimeLogConfig(context.Background()) + if err != nil { + t.Fatalf("nil svc should fallback default: %v", err) + } + if cfg.Level != "info" { + t.Fatalf("unexpected nil svc default level: %s", 
cfg.Level) + } + + repo := newRuntimeSettingRepoStub() + repo.getValueFn = func(key string) (string, error) { + return "", errors.New("boom") + } + svc := &OpsService{ + settingRepo: repo, + cfg: &config.Config{ + Log: config.LogConfig{ + Level: "warn", + Caller: true, + StacktraceLevel: "error", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + }, + } + if _, err := svc.GetRuntimeLogConfig(context.Background()); err == nil { + t.Fatalf("expected get value error") + } +} + +func TestUpdateRuntimeLogConfig_PreconditionErrors(t *testing.T) { + svc := &OpsService{} + if _, err := svc.UpdateRuntimeLogConfig(context.Background(), &OpsRuntimeLogConfig{}, 1); err == nil { + t.Fatalf("expected setting repo not initialized") + } + + svc = &OpsService{settingRepo: newRuntimeSettingRepoStub()} + if _, err := svc.UpdateRuntimeLogConfig(context.Background(), nil, 1); err == nil { + t.Fatalf("expected invalid config") + } + if _, err := svc.UpdateRuntimeLogConfig(context.Background(), &OpsRuntimeLogConfig{ + Level: "info", + StacktraceLevel: "error", + SamplingInitial: 1, + SamplingNext: 1, + RetentionDays: 1, + }, 0); err == nil { + t.Fatalf("expected invalid operator") + } +} + +func TestUpdateRuntimeLogConfig_Success(t *testing.T) { + repo := newRuntimeSettingRepoStub() + svc := &OpsService{ + settingRepo: repo, + cfg: &config.Config{ + Log: config.LogConfig{ + Level: "info", + Caller: true, + StacktraceLevel: "error", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + }, + } + + if err := logger.Init(logger.InitOptions{ + Level: "info", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: true, + ToFile: false, + }, + }); err != nil { + t.Fatalf("init logger: %v", err) + } + + next, err := svc.UpdateRuntimeLogConfig(context.Background(), &OpsRuntimeLogConfig{ + Level: "debug", + EnableSampling: false, + 
SamplingInitial: 100, + SamplingNext: 100, + Caller: true, + StacktraceLevel: "error", + RetentionDays: 30, + }, 2) + if err != nil { + t.Fatalf("UpdateRuntimeLogConfig() error: %v", err) + } + if next.Source != "runtime_setting" || next.UpdatedByUserID != 2 || next.UpdatedAt == "" { + t.Fatalf("unexpected metadata: %+v", next) + } + if logger.CurrentLevel() != "debug" { + t.Fatalf("expected applied level debug, got %s", logger.CurrentLevel()) + } +} + +func TestResetRuntimeLogConfig_IgnoreNotFoundDelete(t *testing.T) { + repo := newRuntimeSettingRepoStub() + repo.deleteFn = func(key string) error { return ErrSettingNotFound } + svc := &OpsService{ + settingRepo: repo, + cfg: &config.Config{ + Log: config.LogConfig{ + Level: "info", + Caller: true, + StacktraceLevel: "error", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + }, + } + if _, err := svc.ResetRuntimeLogConfig(context.Background(), 1); err != nil { + t.Fatalf("reset should ignore ErrSettingNotFound: %v", err) + } +} + +func TestApplyRuntimeLogConfigHelpers(t *testing.T) { + if err := applyOpsRuntimeLogConfig(nil); err == nil { + t.Fatalf("expected nil config error") + } + + normalizeOpsRuntimeLogConfig(nil, &OpsRuntimeLogConfig{Level: "info"}) + normalizeOpsRuntimeLogConfig(&OpsRuntimeLogConfig{Level: "debug"}, nil) + + var nilSvc *OpsService + nilSvc.applyRuntimeLogConfigOnStartup(context.Background()) +} diff --git a/backend/internal/service/ops_models.go b/backend/internal/service/ops_models.go index 347cd52b..2ed06d90 100644 --- a/backend/internal/service/ops_models.go +++ b/backend/internal/service/ops_models.go @@ -2,6 +2,21 @@ package service import "time" +type OpsSystemLog struct { + ID int64 `json:"id"` + CreatedAt time.Time `json:"created_at"` + Level string `json:"level"` + Component string `json:"component"` + Message string `json:"message"` + RequestID string `json:"request_id"` + ClientRequestID string `json:"client_request_id"` + 
UserID *int64 `json:"user_id"` + AccountID *int64 `json:"account_id"` + Platform string `json:"platform"` + Model string `json:"model"` + Extra map[string]any `json:"extra,omitempty"` +} + type OpsErrorLog struct { ID int64 `json:"id"` CreatedAt time.Time `json:"created_at"` diff --git a/backend/internal/service/ops_openai_token_stats.go b/backend/internal/service/ops_openai_token_stats.go new file mode 100644 index 00000000..63f88ba0 --- /dev/null +++ b/backend/internal/service/ops_openai_token_stats.go @@ -0,0 +1,55 @@ +package service + +import ( + "context" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +func (s *OpsService) GetOpenAITokenStats(ctx context.Context, filter *OpsOpenAITokenStatsFilter) (*OpsOpenAITokenStatsResponse, error) { + if err := s.RequireMonitoringEnabled(ctx); err != nil { + return nil, err + } + if s.opsRepo == nil { + return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available") + } + if filter == nil { + return nil, infraerrors.BadRequest("OPS_FILTER_REQUIRED", "filter is required") + } + if filter.StartTime.IsZero() || filter.EndTime.IsZero() { + return nil, infraerrors.BadRequest("OPS_TIME_RANGE_REQUIRED", "start_time/end_time are required") + } + if filter.StartTime.After(filter.EndTime) { + return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time") + } + + if filter.GroupID != nil && *filter.GroupID <= 0 { + return nil, infraerrors.BadRequest("OPS_GROUP_ID_INVALID", "group_id must be > 0") + } + + // top_n cannot be mixed with page/page_size params. 
+ if filter.TopN > 0 && (filter.Page > 0 || filter.PageSize > 0) { + return nil, infraerrors.BadRequest("OPS_PAGINATION_CONFLICT", "top_n cannot be used with page/page_size") + } + + if filter.TopN > 0 { + if filter.TopN < 1 || filter.TopN > 100 { + return nil, infraerrors.BadRequest("OPS_TOPN_INVALID", "top_n must be between 1 and 100") + } + } else { + if filter.Page <= 0 { + filter.Page = 1 + } + if filter.PageSize <= 0 { + filter.PageSize = 20 + } + if filter.Page < 1 { + return nil, infraerrors.BadRequest("OPS_PAGE_INVALID", "page must be >= 1") + } + if filter.PageSize < 1 || filter.PageSize > 100 { + return nil, infraerrors.BadRequest("OPS_PAGE_SIZE_INVALID", "page_size must be between 1 and 100") + } + } + + return s.opsRepo.GetOpenAITokenStats(ctx, filter) +} diff --git a/backend/internal/service/ops_openai_token_stats_models.go b/backend/internal/service/ops_openai_token_stats_models.go new file mode 100644 index 00000000..ef40fa1f --- /dev/null +++ b/backend/internal/service/ops_openai_token_stats_models.go @@ -0,0 +1,54 @@ +package service + +import "time" + +type OpsOpenAITokenStatsFilter struct { + TimeRange string + StartTime time.Time + EndTime time.Time + + Platform string + GroupID *int64 + + // Pagination mode (default): page/page_size + Page int + PageSize int + + // TopN mode: top_n + TopN int +} + +func (f *OpsOpenAITokenStatsFilter) IsTopNMode() bool { + return f != nil && f.TopN > 0 +} + +type OpsOpenAITokenStatsItem struct { + Model string `json:"model"` + RequestCount int64 `json:"request_count"` + AvgTokensPerSec *float64 `json:"avg_tokens_per_sec"` + AvgFirstTokenMs *float64 `json:"avg_first_token_ms"` + TotalOutputTokens int64 `json:"total_output_tokens"` + AvgDurationMs int64 `json:"avg_duration_ms"` + RequestsWithFirstToken int64 `json:"requests_with_first_token"` +} + +type OpsOpenAITokenStatsResponse struct { + TimeRange string `json:"time_range"` + StartTime time.Time `json:"start_time"` + EndTime time.Time `json:"end_time"` + + 
Platform string `json:"platform,omitempty"` + GroupID *int64 `json:"group_id,omitempty"` + + Items []*OpsOpenAITokenStatsItem `json:"items"` + + // Total model rows before pagination/topN trimming. + Total int64 `json:"total"` + + // Pagination mode metadata. + Page int `json:"page,omitempty"` + PageSize int `json:"page_size,omitempty"` + + // TopN mode metadata. + TopN *int `json:"top_n,omitempty"` +} diff --git a/backend/internal/service/ops_openai_token_stats_test.go b/backend/internal/service/ops_openai_token_stats_test.go new file mode 100644 index 00000000..ee332f91 --- /dev/null +++ b/backend/internal/service/ops_openai_token_stats_test.go @@ -0,0 +1,162 @@ +package service + +import ( + "context" + "testing" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +type openAITokenStatsRepoStub struct { + OpsRepository + resp *OpsOpenAITokenStatsResponse + err error + captured *OpsOpenAITokenStatsFilter +} + +func (s *openAITokenStatsRepoStub) GetOpenAITokenStats(ctx context.Context, filter *OpsOpenAITokenStatsFilter) (*OpsOpenAITokenStatsResponse, error) { + s.captured = filter + if s.err != nil { + return nil, s.err + } + if s.resp != nil { + return s.resp, nil + } + return &OpsOpenAITokenStatsResponse{}, nil +} + +func TestOpsServiceGetOpenAITokenStats_Validation(t *testing.T) { + now := time.Now().UTC() + + tests := []struct { + name string + filter *OpsOpenAITokenStatsFilter + wantCode int + wantReason string + }{ + { + name: "filter 不能为空", + filter: nil, + wantCode: 400, + wantReason: "OPS_FILTER_REQUIRED", + }, + { + name: "start_time/end_time 必填", + filter: &OpsOpenAITokenStatsFilter{ + StartTime: time.Time{}, + EndTime: now, + }, + wantCode: 400, + wantReason: "OPS_TIME_RANGE_REQUIRED", + }, + { + name: "start_time 不能晚于 end_time", + filter: &OpsOpenAITokenStatsFilter{ + StartTime: now, + EndTime: now.Add(-1 * time.Minute), + }, + wantCode: 400, + wantReason: "OPS_TIME_RANGE_INVALID", + 
}, + { + name: "group_id 必须大于 0", + filter: &OpsOpenAITokenStatsFilter{ + StartTime: now.Add(-time.Hour), + EndTime: now, + GroupID: int64Ptr(0), + }, + wantCode: 400, + wantReason: "OPS_GROUP_ID_INVALID", + }, + { + name: "top_n 与分页参数互斥", + filter: &OpsOpenAITokenStatsFilter{ + StartTime: now.Add(-time.Hour), + EndTime: now, + TopN: 10, + Page: 1, + }, + wantCode: 400, + wantReason: "OPS_PAGINATION_CONFLICT", + }, + { + name: "top_n 参数越界", + filter: &OpsOpenAITokenStatsFilter{ + StartTime: now.Add(-time.Hour), + EndTime: now, + TopN: 101, + }, + wantCode: 400, + wantReason: "OPS_TOPN_INVALID", + }, + { + name: "page_size 参数越界", + filter: &OpsOpenAITokenStatsFilter{ + StartTime: now.Add(-time.Hour), + EndTime: now, + Page: 1, + PageSize: 101, + }, + wantCode: 400, + wantReason: "OPS_PAGE_SIZE_INVALID", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc := &OpsService{ + opsRepo: &openAITokenStatsRepoStub{}, + } + + _, err := svc.GetOpenAITokenStats(context.Background(), tt.filter) + require.Error(t, err) + require.Equal(t, tt.wantCode, infraerrors.Code(err)) + require.Equal(t, tt.wantReason, infraerrors.Reason(err)) + }) + } +} + +func TestOpsServiceGetOpenAITokenStats_DefaultPagination(t *testing.T) { + now := time.Now().UTC() + repo := &openAITokenStatsRepoStub{ + resp: &OpsOpenAITokenStatsResponse{ + Items: []*OpsOpenAITokenStatsItem{ + {Model: "gpt-4o-mini", RequestCount: 10}, + }, + Total: 1, + }, + } + svc := &OpsService{opsRepo: repo} + + filter := &OpsOpenAITokenStatsFilter{ + TimeRange: "30d", + StartTime: now.Add(-30 * 24 * time.Hour), + EndTime: now, + } + resp, err := svc.GetOpenAITokenStats(context.Background(), filter) + require.NoError(t, err) + require.NotNil(t, resp) + require.NotNil(t, repo.captured) + require.Equal(t, 1, repo.captured.Page) + require.Equal(t, 20, repo.captured.PageSize) + require.Equal(t, 0, repo.captured.TopN) +} + +func TestOpsServiceGetOpenAITokenStats_RepoUnavailable(t *testing.T) { + now := 
time.Now().UTC() + svc := &OpsService{} + + _, err := svc.GetOpenAITokenStats(context.Background(), &OpsOpenAITokenStatsFilter{ + TimeRange: "1h", + StartTime: now.Add(-time.Hour), + EndTime: now, + TopN: 10, + }) + require.Error(t, err) + require.Equal(t, 503, infraerrors.Code(err)) + require.Equal(t, "OPS_REPO_UNAVAILABLE", infraerrors.Reason(err)) +} + +func int64Ptr(v int64) *int64 { return &v } diff --git a/backend/internal/service/ops_port.go b/backend/internal/service/ops_port.go index 347b06b5..f3633eae 100644 --- a/backend/internal/service/ops_port.go +++ b/backend/internal/service/ops_port.go @@ -10,6 +10,10 @@ type OpsRepository interface { ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error) ListRequestDetails(ctx context.Context, filter *OpsRequestDetailFilter) ([]*OpsRequestDetail, int64, error) + BatchInsertSystemLogs(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) + ListSystemLogs(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) + DeleteSystemLogs(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) + InsertSystemLogCleanupAudit(ctx context.Context, input *OpsSystemLogCleanupAudit) error InsertRetryAttempt(ctx context.Context, input *OpsInsertRetryAttemptInput) (int64, error) UpdateRetryAttempt(ctx context.Context, input *OpsUpdateRetryAttemptInput) error @@ -27,6 +31,7 @@ type OpsRepository interface { GetLatencyHistogram(ctx context.Context, filter *OpsDashboardFilter) (*OpsLatencyHistogramResponse, error) GetErrorTrend(ctx context.Context, filter *OpsDashboardFilter, bucketSeconds int) (*OpsErrorTrendResponse, error) GetErrorDistribution(ctx context.Context, filter *OpsDashboardFilter) (*OpsErrorDistributionResponse, error) + GetOpenAITokenStats(ctx context.Context, filter *OpsOpenAITokenStatsFilter) (*OpsOpenAITokenStatsResponse, error) InsertSystemMetrics(ctx 
context.Context, input *OpsInsertSystemMetricsInput) error GetLatestSystemMetrics(ctx context.Context, windowMinutes int) (*OpsSystemMetricsSnapshot, error) @@ -98,6 +103,10 @@ type OpsInsertErrorLogInput struct { // It is set by OpsService.RecordError before persisting. UpstreamErrorsJSON *string + AuthLatencyMs *int64 + RoutingLatencyMs *int64 + UpstreamLatencyMs *int64 + ResponseLatencyMs *int64 TimeToFirstTokenMs *int64 RequestBodyJSON *string // sanitized json string (not raw bytes) @@ -200,6 +209,69 @@ type OpsInsertSystemMetricsInput struct { ConcurrencyQueueDepth *int } +type OpsInsertSystemLogInput struct { + CreatedAt time.Time + Level string + Component string + Message string + RequestID string + ClientRequestID string + UserID *int64 + AccountID *int64 + Platform string + Model string + ExtraJSON string +} + +type OpsSystemLogFilter struct { + StartTime *time.Time + EndTime *time.Time + + Level string + Component string + + RequestID string + ClientRequestID string + UserID *int64 + AccountID *int64 + Platform string + Model string + Query string + + Page int + PageSize int +} + +type OpsSystemLogCleanupFilter struct { + StartTime *time.Time + EndTime *time.Time + + Level string + Component string + + RequestID string + ClientRequestID string + UserID *int64 + AccountID *int64 + Platform string + Model string + Query string +} + +type OpsSystemLogList struct { + Logs []*OpsSystemLog `json:"logs"` + Total int `json:"total"` + Page int `json:"page"` + PageSize int `json:"page_size"` +} + +type OpsSystemLogCleanupAudit struct { + CreatedAt time.Time + OperatorID int64 + Conditions string + DeletedRows int64 +} + type OpsSystemMetricsSnapshot struct { ID int64 `json:"id"` CreatedAt time.Time `json:"created_at"` diff --git a/backend/internal/service/ops_query_mode.go b/backend/internal/service/ops_query_mode.go index e6fa9c1e..fa97f358 100644 --- a/backend/internal/service/ops_query_mode.go +++ b/backend/internal/service/ops_query_mode.go @@ -38,3 +38,18 @@ 
func (m OpsQueryMode) IsValid() bool { return false } } + +func shouldFallbackOpsPreagg(filter *OpsDashboardFilter, err error) bool { + return filter != nil && + filter.QueryMode == OpsQueryModeAuto && + errors.Is(err, ErrOpsPreaggregatedNotPopulated) +} + +func cloneOpsFilterWithMode(filter *OpsDashboardFilter, mode OpsQueryMode) *OpsDashboardFilter { + if filter == nil { + return nil + } + cloned := *filter + cloned.QueryMode = mode + return &cloned +} diff --git a/backend/internal/service/ops_query_mode_test.go b/backend/internal/service/ops_query_mode_test.go new file mode 100644 index 00000000..26c4b730 --- /dev/null +++ b/backend/internal/service/ops_query_mode_test.go @@ -0,0 +1,66 @@ +//go:build unit + +package service + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestShouldFallbackOpsPreagg(t *testing.T) { + preaggErr := ErrOpsPreaggregatedNotPopulated + otherErr := errors.New("some other error") + + autoFilter := &OpsDashboardFilter{QueryMode: OpsQueryModeAuto} + rawFilter := &OpsDashboardFilter{QueryMode: OpsQueryModeRaw} + preaggFilter := &OpsDashboardFilter{QueryMode: OpsQueryModePreagg} + + tests := []struct { + name string + filter *OpsDashboardFilter + err error + want bool + }{ + {"auto mode + preagg error => fallback", autoFilter, preaggErr, true}, + {"auto mode + other error => no fallback", autoFilter, otherErr, false}, + {"auto mode + nil error => no fallback", autoFilter, nil, false}, + {"raw mode + preagg error => no fallback", rawFilter, preaggErr, false}, + {"preagg mode + preagg error => no fallback", preaggFilter, preaggErr, false}, + {"nil filter => no fallback", nil, preaggErr, false}, + {"wrapped preagg error => fallback", autoFilter, errors.Join(preaggErr, otherErr), true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := shouldFallbackOpsPreagg(tc.filter, tc.err) + require.Equal(t, tc.want, got) + }) + } +} + +func TestCloneOpsFilterWithMode(t 
*testing.T) { + t.Run("nil filter returns nil", func(t *testing.T) { + require.Nil(t, cloneOpsFilterWithMode(nil, OpsQueryModeRaw)) + }) + + t.Run("cloned filter has new mode", func(t *testing.T) { + groupID := int64(42) + original := &OpsDashboardFilter{ + StartTime: time.Now(), + EndTime: time.Now().Add(time.Hour), + Platform: "anthropic", + GroupID: &groupID, + QueryMode: OpsQueryModeAuto, + } + + cloned := cloneOpsFilterWithMode(original, OpsQueryModeRaw) + require.Equal(t, OpsQueryModeRaw, cloned.QueryMode) + require.Equal(t, OpsQueryModeAuto, original.QueryMode, "original should not be modified") + require.Equal(t, original.Platform, cloned.Platform) + require.Equal(t, original.StartTime, cloned.StartTime) + require.Equal(t, original.GroupID, cloned.GroupID) + }) +} diff --git a/backend/internal/service/ops_repo_mock_test.go b/backend/internal/service/ops_repo_mock_test.go new file mode 100644 index 00000000..e250dea3 --- /dev/null +++ b/backend/internal/service/ops_repo_mock_test.go @@ -0,0 +1,196 @@ +package service + +import ( + "context" + "time" +) + +// opsRepoMock is a test-only OpsRepository implementation with optional function hooks. 
+type opsRepoMock struct { + BatchInsertSystemLogsFn func(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) + ListSystemLogsFn func(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) + DeleteSystemLogsFn func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) + InsertSystemLogCleanupAuditFn func(ctx context.Context, input *OpsSystemLogCleanupAudit) error +} + +func (m *opsRepoMock) InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) { + return 0, nil +} + +func (m *opsRepoMock) ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error) { + return &OpsErrorLogList{Errors: []*OpsErrorLog{}, Page: 1, PageSize: 20}, nil +} + +func (m *opsRepoMock) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error) { + return &OpsErrorLogDetail{}, nil +} + +func (m *opsRepoMock) ListRequestDetails(ctx context.Context, filter *OpsRequestDetailFilter) ([]*OpsRequestDetail, int64, error) { + return []*OpsRequestDetail{}, 0, nil +} + +func (m *opsRepoMock) BatchInsertSystemLogs(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) { + if m.BatchInsertSystemLogsFn != nil { + return m.BatchInsertSystemLogsFn(ctx, inputs) + } + return int64(len(inputs)), nil +} + +func (m *opsRepoMock) ListSystemLogs(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) { + if m.ListSystemLogsFn != nil { + return m.ListSystemLogsFn(ctx, filter) + } + return &OpsSystemLogList{Logs: []*OpsSystemLog{}, Total: 0, Page: 1, PageSize: 50}, nil +} + +func (m *opsRepoMock) DeleteSystemLogs(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) { + if m.DeleteSystemLogsFn != nil { + return m.DeleteSystemLogsFn(ctx, filter) + } + return 0, nil +} + +func (m *opsRepoMock) InsertSystemLogCleanupAudit(ctx context.Context, input *OpsSystemLogCleanupAudit) error { + if m.InsertSystemLogCleanupAuditFn != nil { + return 
m.InsertSystemLogCleanupAuditFn(ctx, input) + } + return nil +} + +func (m *opsRepoMock) InsertRetryAttempt(ctx context.Context, input *OpsInsertRetryAttemptInput) (int64, error) { + return 0, nil +} + +func (m *opsRepoMock) UpdateRetryAttempt(ctx context.Context, input *OpsUpdateRetryAttemptInput) error { + return nil +} + +func (m *opsRepoMock) GetLatestRetryAttemptForError(ctx context.Context, sourceErrorID int64) (*OpsRetryAttempt, error) { + return nil, nil +} + +func (m *opsRepoMock) ListRetryAttemptsByErrorID(ctx context.Context, sourceErrorID int64, limit int) ([]*OpsRetryAttempt, error) { + return []*OpsRetryAttempt{}, nil +} + +func (m *opsRepoMock) UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64, resolvedAt *time.Time) error { + return nil +} + +func (m *opsRepoMock) GetWindowStats(ctx context.Context, filter *OpsDashboardFilter) (*OpsWindowStats, error) { + return &OpsWindowStats{}, nil +} + +func (m *opsRepoMock) GetRealtimeTrafficSummary(ctx context.Context, filter *OpsDashboardFilter) (*OpsRealtimeTrafficSummary, error) { + return &OpsRealtimeTrafficSummary{}, nil +} + +func (m *opsRepoMock) GetDashboardOverview(ctx context.Context, filter *OpsDashboardFilter) (*OpsDashboardOverview, error) { + return &OpsDashboardOverview{}, nil +} + +func (m *opsRepoMock) GetThroughputTrend(ctx context.Context, filter *OpsDashboardFilter, bucketSeconds int) (*OpsThroughputTrendResponse, error) { + return &OpsThroughputTrendResponse{}, nil +} + +func (m *opsRepoMock) GetLatencyHistogram(ctx context.Context, filter *OpsDashboardFilter) (*OpsLatencyHistogramResponse, error) { + return &OpsLatencyHistogramResponse{}, nil +} + +func (m *opsRepoMock) GetErrorTrend(ctx context.Context, filter *OpsDashboardFilter, bucketSeconds int) (*OpsErrorTrendResponse, error) { + return &OpsErrorTrendResponse{}, nil +} + +func (m *opsRepoMock) GetErrorDistribution(ctx context.Context, filter *OpsDashboardFilter) 
(*OpsErrorDistributionResponse, error) { + return &OpsErrorDistributionResponse{}, nil +} + +func (m *opsRepoMock) GetOpenAITokenStats(ctx context.Context, filter *OpsOpenAITokenStatsFilter) (*OpsOpenAITokenStatsResponse, error) { + return &OpsOpenAITokenStatsResponse{}, nil +} + +func (m *opsRepoMock) InsertSystemMetrics(ctx context.Context, input *OpsInsertSystemMetricsInput) error { + return nil +} + +func (m *opsRepoMock) GetLatestSystemMetrics(ctx context.Context, windowMinutes int) (*OpsSystemMetricsSnapshot, error) { + return &OpsSystemMetricsSnapshot{}, nil +} + +func (m *opsRepoMock) UpsertJobHeartbeat(ctx context.Context, input *OpsUpsertJobHeartbeatInput) error { + return nil +} + +func (m *opsRepoMock) ListJobHeartbeats(ctx context.Context) ([]*OpsJobHeartbeat, error) { + return []*OpsJobHeartbeat{}, nil +} + +func (m *opsRepoMock) ListAlertRules(ctx context.Context) ([]*OpsAlertRule, error) { + return []*OpsAlertRule{}, nil +} + +func (m *opsRepoMock) CreateAlertRule(ctx context.Context, input *OpsAlertRule) (*OpsAlertRule, error) { + return input, nil +} + +func (m *opsRepoMock) UpdateAlertRule(ctx context.Context, input *OpsAlertRule) (*OpsAlertRule, error) { + return input, nil +} + +func (m *opsRepoMock) DeleteAlertRule(ctx context.Context, id int64) error { + return nil +} + +func (m *opsRepoMock) ListAlertEvents(ctx context.Context, filter *OpsAlertEventFilter) ([]*OpsAlertEvent, error) { + return []*OpsAlertEvent{}, nil +} + +func (m *opsRepoMock) GetAlertEventByID(ctx context.Context, eventID int64) (*OpsAlertEvent, error) { + return &OpsAlertEvent{}, nil +} + +func (m *opsRepoMock) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) { + return nil, nil +} + +func (m *opsRepoMock) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) { + return nil, nil +} + +func (m *opsRepoMock) CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) (*OpsAlertEvent, error) { + return event, nil +} + 
+func (m *opsRepoMock) UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error { + return nil +} + +func (m *opsRepoMock) UpdateAlertEventEmailSent(ctx context.Context, eventID int64, emailSent bool) error { + return nil +} + +func (m *opsRepoMock) CreateAlertSilence(ctx context.Context, input *OpsAlertSilence) (*OpsAlertSilence, error) { + return input, nil +} + +func (m *opsRepoMock) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) { + return false, nil +} + +func (m *opsRepoMock) UpsertHourlyMetrics(ctx context.Context, startTime, endTime time.Time) error { + return nil +} + +func (m *opsRepoMock) UpsertDailyMetrics(ctx context.Context, startTime, endTime time.Time) error { + return nil +} + +func (m *opsRepoMock) GetLatestHourlyBucketStart(ctx context.Context) (time.Time, bool, error) { + return time.Time{}, false, nil +} + +func (m *opsRepoMock) GetLatestDailyBucketDate(ctx context.Context) (time.Time, bool, error) { + return time.Time{}, false, nil +} + +var _ OpsRepository = (*opsRepoMock)(nil) diff --git a/backend/internal/service/ops_retry.go b/backend/internal/service/ops_retry.go index 23a524ad..f0daa3e2 100644 --- a/backend/internal/service/ops_retry.go +++ b/backend/internal/service/ops_retry.go @@ -13,7 +13,6 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/domain" - "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/gin-gonic/gin" "github.com/lib/pq" @@ -480,7 +479,7 @@ func (s *OpsService) executeClientRetry(ctx context.Context, reqType opsRetryReq attemptCtx := ctx if switches > 0 { - attemptCtx = context.WithValue(attemptCtx, ctxkey.AccountSwitchCount, switches) + attemptCtx = WithAccountSwitchCount(attemptCtx, switches, false) } exec := func() *opsRetryExecution { defer selection.ReleaseFunc() @@ -675,6 +674,7 @@ func newOpsRetryContext(ctx 
context.Context, errorLog *OpsErrorLogDetail) (*gin. } c.Request = req + SetOpenAIClientTransport(c, OpenAIClientTransportHTTP) return c, w } diff --git a/backend/internal/service/ops_retry_context_test.go b/backend/internal/service/ops_retry_context_test.go new file mode 100644 index 00000000..a8c26ee4 --- /dev/null +++ b/backend/internal/service/ops_retry_context_test.go @@ -0,0 +1,47 @@ +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewOpsRetryContext_SetsHTTPTransportAndRequestHeaders(t *testing.T) { + errorLog := &OpsErrorLogDetail{ + OpsErrorLog: OpsErrorLog{ + RequestPath: "/openai/v1/responses", + }, + UserAgent: "ops-retry-agent/1.0", + RequestHeaders: `{ + "anthropic-beta":"beta-v1", + "ANTHROPIC-VERSION":"2023-06-01", + "authorization":"Bearer should-not-forward" + }`, + } + + c, w := newOpsRetryContext(context.Background(), errorLog) + require.NotNil(t, c) + require.NotNil(t, w) + require.NotNil(t, c.Request) + + require.Equal(t, "/openai/v1/responses", c.Request.URL.Path) + require.Equal(t, "application/json", c.Request.Header.Get("Content-Type")) + require.Equal(t, "ops-retry-agent/1.0", c.Request.Header.Get("User-Agent")) + require.Equal(t, "beta-v1", c.Request.Header.Get("anthropic-beta")) + require.Equal(t, "2023-06-01", c.Request.Header.Get("anthropic-version")) + require.Empty(t, c.Request.Header.Get("authorization"), "未在白名单内的敏感头不应被重放") + require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c)) +} + +func TestNewOpsRetryContext_InvalidHeadersJSONStillSetsHTTPTransport(t *testing.T) { + errorLog := &OpsErrorLogDetail{ + RequestHeaders: "{invalid-json", + } + + c, _ := newOpsRetryContext(context.Background(), errorLog) + require.NotNil(t, c) + require.NotNil(t, c.Request) + require.Equal(t, "/", c.Request.URL.Path) + require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c)) +} diff --git a/backend/internal/service/ops_service.go 
b/backend/internal/service/ops_service.go index 9c121b8b..767d1704 100644 --- a/backend/internal/service/ops_service.go +++ b/backend/internal/service/ops_service.go @@ -20,6 +20,22 @@ const ( opsMaxStoredErrorBodyBytes = 20 * 1024 ) +// PrepareOpsRequestBodyForQueue 在入队前对请求体执行脱敏与裁剪,返回可直接写入 OpsInsertErrorLogInput 的字段。 +// 该方法用于避免异步队列持有大块原始请求体,减少错误风暴下的内存放大风险。 +func PrepareOpsRequestBodyForQueue(raw []byte) (requestBodyJSON *string, truncated bool, requestBodyBytes *int) { + if len(raw) == 0 { + return nil, false, nil + } + sanitized, truncated, bytesLen := sanitizeAndTrimRequestBody(raw, opsMaxStoredRequestBodyBytes) + if sanitized != "" { + out := sanitized + requestBodyJSON = &out + } + n := bytesLen + requestBodyBytes = &n + return requestBodyJSON, truncated, requestBodyBytes +} + // OpsService provides ingestion and query APIs for the Ops monitoring module. type OpsService struct { opsRepo OpsRepository @@ -37,6 +53,7 @@ type OpsService struct { openAIGatewayService *OpenAIGatewayService geminiCompatService *GeminiMessagesCompatService antigravityGatewayService *AntigravityGatewayService + systemLogSink *OpsSystemLogSink } func NewOpsService( @@ -50,8 +67,9 @@ func NewOpsService( openAIGatewayService *OpenAIGatewayService, geminiCompatService *GeminiMessagesCompatService, antigravityGatewayService *AntigravityGatewayService, + systemLogSink *OpsSystemLogSink, ) *OpsService { - return &OpsService{ + svc := &OpsService{ opsRepo: opsRepo, settingRepo: settingRepo, cfg: cfg, @@ -64,7 +82,10 @@ func NewOpsService( openAIGatewayService: openAIGatewayService, geminiCompatService: geminiCompatService, antigravityGatewayService: antigravityGatewayService, + systemLogSink: systemLogSink, } + svc.applyRuntimeLogConfigOnStartup(context.Background()) + return svc } func (s *OpsService) RequireMonitoringEnabled(ctx context.Context) error { @@ -127,12 +148,7 @@ func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogIn // Sanitize + trim request body 
(errors only). if len(rawRequestBody) > 0 { - sanitized, truncated, bytesLen := sanitizeAndTrimRequestBody(rawRequestBody, opsMaxStoredRequestBodyBytes) - if sanitized != "" { - entry.RequestBodyJSON = &sanitized - } - entry.RequestBodyTruncated = truncated - entry.RequestBodyBytes = &bytesLen + entry.RequestBodyJSON, entry.RequestBodyTruncated, entry.RequestBodyBytes = PrepareOpsRequestBodyForQueue(rawRequestBody) } // Sanitize + truncate error_body to avoid storing sensitive data. diff --git a/backend/internal/service/ops_service_prepare_queue_test.go b/backend/internal/service/ops_service_prepare_queue_test.go new file mode 100644 index 00000000..d6f32c2d --- /dev/null +++ b/backend/internal/service/ops_service_prepare_queue_test.go @@ -0,0 +1,60 @@ +package service + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPrepareOpsRequestBodyForQueue_EmptyBody(t *testing.T) { + requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(nil) + require.Nil(t, requestBodyJSON) + require.False(t, truncated) + require.Nil(t, requestBodyBytes) +} + +func TestPrepareOpsRequestBodyForQueue_InvalidJSON(t *testing.T) { + raw := []byte("{invalid-json") + requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(raw) + require.Nil(t, requestBodyJSON) + require.False(t, truncated) + require.NotNil(t, requestBodyBytes) + require.Equal(t, len(raw), *requestBodyBytes) +} + +func TestPrepareOpsRequestBodyForQueue_RedactSensitiveFields(t *testing.T) { + raw := []byte(`{ + "model":"claude-3-5-sonnet-20241022", + "api_key":"sk-test-123", + "headers":{"authorization":"Bearer secret-token"}, + "messages":[{"role":"user","content":"hello"}] + }`) + + requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(raw) + require.NotNil(t, requestBodyJSON) + require.NotNil(t, requestBodyBytes) + require.False(t, truncated) + require.Equal(t, len(raw), *requestBodyBytes) + + var 
body map[string]any + require.NoError(t, json.Unmarshal([]byte(*requestBodyJSON), &body)) + require.Equal(t, "[REDACTED]", body["api_key"]) + headers, ok := body["headers"].(map[string]any) + require.True(t, ok) + require.Equal(t, "[REDACTED]", headers["authorization"]) +} + +func TestPrepareOpsRequestBodyForQueue_LargeBodyTruncated(t *testing.T) { + largeMsg := strings.Repeat("x", opsMaxStoredRequestBodyBytes*2) + raw := []byte(`{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":"` + largeMsg + `"}]}`) + + requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(raw) + require.NotNil(t, requestBodyJSON) + require.NotNil(t, requestBodyBytes) + require.True(t, truncated) + require.Equal(t, len(raw), *requestBodyBytes) + require.LessOrEqual(t, len(*requestBodyJSON), opsMaxStoredRequestBodyBytes) + require.Contains(t, *requestBodyJSON, "request_body_truncated") +} diff --git a/backend/internal/service/ops_settings.go b/backend/internal/service/ops_settings.go index a6a4a0d7..7514cc80 100644 --- a/backend/internal/service/ops_settings.go +++ b/backend/internal/service/ops_settings.go @@ -368,7 +368,7 @@ func defaultOpsAdvancedSettings() *OpsAdvancedSettings { Aggregation: OpsAggregationSettings{ AggregationEnabled: false, }, - IgnoreCountTokensErrors: false, + IgnoreCountTokensErrors: true, // count_tokens 404 是预期行为,默认忽略 IgnoreContextCanceled: true, // Default to true - client disconnects are not errors IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue AutoRefreshEnabled: false, diff --git a/backend/internal/service/ops_settings_models.go b/backend/internal/service/ops_settings_models.go index ecc62220..8b5359e3 100644 --- a/backend/internal/service/ops_settings_models.go +++ b/backend/internal/service/ops_settings_models.go @@ -68,6 +68,20 @@ type OpsMetricThresholds struct { UpstreamErrorRatePercentMax *float64 `json:"upstream_error_rate_percent_max,omitempty"` // 上游错误率高于此值变红 } +type 
OpsRuntimeLogConfig struct { + Level string `json:"level"` + EnableSampling bool `json:"enable_sampling"` + SamplingInitial int `json:"sampling_initial"` + SamplingNext int `json:"sampling_thereafter"` + Caller bool `json:"caller"` + StacktraceLevel string `json:"stacktrace_level"` + RetentionDays int `json:"retention_days"` + Source string `json:"source,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` + UpdatedByUserID int64 `json:"updated_by_user_id,omitempty"` + Extra map[string]any `json:"extra,omitempty"` +} + type OpsAlertRuntimeSettings struct { EvaluationIntervalSeconds int `json:"evaluation_interval_seconds"` diff --git a/backend/internal/service/ops_system_log_service.go b/backend/internal/service/ops_system_log_service.go new file mode 100644 index 00000000..f5a64803 --- /dev/null +++ b/backend/internal/service/ops_system_log_service.go @@ -0,0 +1,124 @@ +package service + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "log" + "strings" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +func (s *OpsService) ListSystemLogs(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) { + if err := s.RequireMonitoringEnabled(ctx); err != nil { + return nil, err + } + if s.opsRepo == nil { + return &OpsSystemLogList{ + Logs: []*OpsSystemLog{}, + Total: 0, + Page: 1, + PageSize: 50, + }, nil + } + if filter == nil { + filter = &OpsSystemLogFilter{} + } + if filter.Page <= 0 { + filter.Page = 1 + } + if filter.PageSize <= 0 { + filter.PageSize = 50 + } + if filter.PageSize > 200 { + filter.PageSize = 200 + } + + result, err := s.opsRepo.ListSystemLogs(ctx, filter) + if err != nil { + return nil, infraerrors.InternalServer("OPS_SYSTEM_LOG_LIST_FAILED", "Failed to list system logs").WithCause(err) + } + return result, nil +} + +func (s *OpsService) CleanupSystemLogs(ctx context.Context, filter *OpsSystemLogCleanupFilter, operatorID int64) (int64, error) { + if err := 
s.RequireMonitoringEnabled(ctx); err != nil { + return 0, err + } + if s.opsRepo == nil { + return 0, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available") + } + if operatorID <= 0 { + return 0, infraerrors.BadRequest("OPS_SYSTEM_LOG_CLEANUP_INVALID_OPERATOR", "invalid operator") + } + if filter == nil { + filter = &OpsSystemLogCleanupFilter{} + } + if filter.EndTime != nil && filter.StartTime != nil && filter.StartTime.After(*filter.EndTime) { + return 0, infraerrors.BadRequest("OPS_SYSTEM_LOG_CLEANUP_INVALID_RANGE", "invalid time range") + } + + deletedRows, err := s.opsRepo.DeleteSystemLogs(ctx, filter) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return 0, nil + } + if strings.Contains(strings.ToLower(err.Error()), "requires at least one filter") { + return 0, infraerrors.BadRequest("OPS_SYSTEM_LOG_CLEANUP_FILTER_REQUIRED", "cleanup requires at least one filter condition") + } + return 0, infraerrors.InternalServer("OPS_SYSTEM_LOG_CLEANUP_FAILED", "Failed to cleanup system logs").WithCause(err) + } + + if auditErr := s.opsRepo.InsertSystemLogCleanupAudit(ctx, &OpsSystemLogCleanupAudit{ + CreatedAt: time.Now().UTC(), + OperatorID: operatorID, + Conditions: marshalSystemLogCleanupConditions(filter), + DeletedRows: deletedRows, + }); auditErr != nil { + // 审计失败不影响主流程,避免运维清理被阻塞。 + log.Printf("[OpsSystemLog] cleanup audit failed: %v", auditErr) + } + return deletedRows, nil +} + +func marshalSystemLogCleanupConditions(filter *OpsSystemLogCleanupFilter) string { + if filter == nil { + return "{}" + } + payload := map[string]any{ + "level": strings.TrimSpace(filter.Level), + "component": strings.TrimSpace(filter.Component), + "request_id": strings.TrimSpace(filter.RequestID), + "client_request_id": strings.TrimSpace(filter.ClientRequestID), + "platform": strings.TrimSpace(filter.Platform), + "model": strings.TrimSpace(filter.Model), + "query": strings.TrimSpace(filter.Query), + } + if filter.UserID != nil { + 
payload["user_id"] = *filter.UserID + } + if filter.AccountID != nil { + payload["account_id"] = *filter.AccountID + } + if filter.StartTime != nil && !filter.StartTime.IsZero() { + payload["start_time"] = filter.StartTime.UTC().Format(time.RFC3339Nano) + } + if filter.EndTime != nil && !filter.EndTime.IsZero() { + payload["end_time"] = filter.EndTime.UTC().Format(time.RFC3339Nano) + } + raw, err := json.Marshal(payload) + if err != nil { + return "{}" + } + return string(raw) +} + +func (s *OpsService) GetSystemLogSinkHealth() OpsSystemLogSinkHealth { + if s == nil || s.systemLogSink == nil { + return OpsSystemLogSinkHealth{} + } + return s.systemLogSink.Health() +} diff --git a/backend/internal/service/ops_system_log_service_test.go b/backend/internal/service/ops_system_log_service_test.go new file mode 100644 index 00000000..cc9ddefe --- /dev/null +++ b/backend/internal/service/ops_system_log_service_test.go @@ -0,0 +1,243 @@ +package service + +import ( + "context" + "database/sql" + "errors" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +func TestOpsServiceListSystemLogs_DefaultClampAndSuccess(t *testing.T) { + var gotFilter *OpsSystemLogFilter + repo := &opsRepoMock{ + ListSystemLogsFn: func(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) { + gotFilter = filter + return &OpsSystemLogList{ + Logs: []*OpsSystemLog{{ID: 1, Level: "warn", Message: "x"}}, + Total: 1, + Page: filter.Page, + PageSize: filter.PageSize, + }, nil + }, + } + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + + out, err := svc.ListSystemLogs(context.Background(), &OpsSystemLogFilter{ + Page: 0, + PageSize: 999, + }) + if err != nil { + t.Fatalf("ListSystemLogs() error: %v", err) + } + if gotFilter == nil { + t.Fatalf("expected repository to receive filter") + } + if gotFilter.Page != 1 || gotFilter.PageSize != 200 { + t.Fatalf("filter normalized unexpectedly: page=%d pageSize=%d", 
gotFilter.Page, gotFilter.PageSize) + } + if out.Total != 1 || len(out.Logs) != 1 { + t.Fatalf("unexpected result: %+v", out) + } +} + +func TestOpsServiceListSystemLogs_MonitoringDisabled(t *testing.T) { + svc := NewOpsService( + &opsRepoMock{}, + nil, + &config.Config{Ops: config.OpsConfig{Enabled: false}}, + nil, nil, nil, nil, nil, nil, nil, nil, + ) + _, err := svc.ListSystemLogs(context.Background(), &OpsSystemLogFilter{}) + if err == nil { + t.Fatalf("expected disabled error") + } +} + +func TestOpsServiceListSystemLogs_NilRepoReturnsEmpty(t *testing.T) { + svc := NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + out, err := svc.ListSystemLogs(context.Background(), nil) + if err != nil { + t.Fatalf("ListSystemLogs() error: %v", err) + } + if out == nil || out.Page != 1 || out.PageSize != 50 || out.Total != 0 || len(out.Logs) != 0 { + t.Fatalf("unexpected nil-repo result: %+v", out) + } +} + +func TestOpsServiceListSystemLogs_RepoErrorMapped(t *testing.T) { + repo := &opsRepoMock{ + ListSystemLogsFn: func(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) { + return nil, errors.New("db down") + }, + } + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + _, err := svc.ListSystemLogs(context.Background(), &OpsSystemLogFilter{}) + if err == nil { + t.Fatalf("expected mapped internal error") + } + if !strings.Contains(err.Error(), "OPS_SYSTEM_LOG_LIST_FAILED") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestOpsServiceCleanupSystemLogs_SuccessAndAudit(t *testing.T) { + var audit *OpsSystemLogCleanupAudit + repo := &opsRepoMock{ + DeleteSystemLogsFn: func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) { + return 3, nil + }, + InsertSystemLogCleanupAuditFn: func(ctx context.Context, input *OpsSystemLogCleanupAudit) error { + audit = input + return nil + }, + } + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + userID := 
int64(7) + now := time.Now().UTC() + filter := &OpsSystemLogCleanupFilter{ + StartTime: &now, + Level: "warn", + RequestID: "req-1", + ClientRequestID: "creq-1", + UserID: &userID, + Query: "timeout", + } + + deleted, err := svc.CleanupSystemLogs(context.Background(), filter, 99) + if err != nil { + t.Fatalf("CleanupSystemLogs() error: %v", err) + } + if deleted != 3 { + t.Fatalf("deleted=%d, want 3", deleted) + } + if audit == nil { + t.Fatalf("expected cleanup audit") + } + if !strings.Contains(audit.Conditions, `"client_request_id":"creq-1"`) { + t.Fatalf("audit conditions should include client_request_id: %s", audit.Conditions) + } + if !strings.Contains(audit.Conditions, `"user_id":7`) { + t.Fatalf("audit conditions should include user_id: %s", audit.Conditions) + } +} + +func TestOpsServiceCleanupSystemLogs_RepoUnavailableAndInvalidOperator(t *testing.T) { + svc := NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + if _, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{RequestID: "r"}, 1); err == nil { + t.Fatalf("expected repo unavailable error") + } + + svc = NewOpsService(&opsRepoMock{}, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + if _, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{RequestID: "r"}, 0); err == nil { + t.Fatalf("expected invalid operator error") + } +} + +func TestOpsServiceCleanupSystemLogs_FilterRequired(t *testing.T) { + repo := &opsRepoMock{ + DeleteSystemLogsFn: func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) { + return 0, errors.New("cleanup requires at least one filter condition") + }, + } + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + _, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{}, 1) + if err == nil { + t.Fatalf("expected filter required error") + } + if !strings.Contains(strings.ToLower(err.Error()), "filter") { + t.Fatalf("unexpected error: %v", 
err) + } +} + +func TestOpsServiceCleanupSystemLogs_InvalidRange(t *testing.T) { + repo := &opsRepoMock{} + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + start := time.Now().UTC() + end := start.Add(-time.Hour) + _, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{ + StartTime: &start, + EndTime: &end, + }, 1) + if err == nil { + t.Fatalf("expected invalid range error") + } +} + +func TestOpsServiceCleanupSystemLogs_NoRowsAndInternalError(t *testing.T) { + repo := &opsRepoMock{ + DeleteSystemLogsFn: func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) { + return 0, sql.ErrNoRows + }, + } + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + deleted, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{ + RequestID: "req-1", + }, 1) + if err != nil || deleted != 0 { + t.Fatalf("expected no rows shortcut, deleted=%d err=%v", deleted, err) + } + + repo.DeleteSystemLogsFn = func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) { + return 0, errors.New("boom") + } + if _, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{ + RequestID: "req-1", + }, 1); err == nil { + t.Fatalf("expected internal cleanup error") + } +} + +func TestOpsServiceCleanupSystemLogs_AuditFailureIgnored(t *testing.T) { + repo := &opsRepoMock{ + DeleteSystemLogsFn: func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) { + return 5, nil + }, + InsertSystemLogCleanupAuditFn: func(ctx context.Context, input *OpsSystemLogCleanupAudit) error { + return errors.New("audit down") + }, + } + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + deleted, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{ + RequestID: "r1", + }, 1) + if err != nil || deleted != 5 { + t.Fatalf("audit failure should not break cleanup, deleted=%d err=%v", deleted, err) + 
} +} + +func TestMarshalSystemLogCleanupConditions_NilAndMarshalError(t *testing.T) { + if got := marshalSystemLogCleanupConditions(nil); got != "{}" { + t.Fatalf("nil filter should return {}, got %s", got) + } + + now := time.Now().UTC() + userID := int64(1) + filter := &OpsSystemLogCleanupFilter{ + StartTime: &now, + EndTime: &now, + UserID: &userID, + } + got := marshalSystemLogCleanupConditions(filter) + if !strings.Contains(got, `"start_time"`) || !strings.Contains(got, `"user_id":1`) { + t.Fatalf("unexpected marshal payload: %s", got) + } +} + +func TestOpsServiceGetSystemLogSinkHealth(t *testing.T) { + svc := NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + health := svc.GetSystemLogSinkHealth() + if health.QueueCapacity != 0 || health.QueueDepth != 0 { + t.Fatalf("unexpected health for nil sink: %+v", health) + } + + sink := NewOpsSystemLogSink(&opsRepoMock{}) + svc = NewOpsService(&opsRepoMock{}, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink) + health = svc.GetSystemLogSinkHealth() + if health.QueueCapacity <= 0 { + t.Fatalf("expected non-zero queue capacity: %+v", health) + } +} diff --git a/backend/internal/service/ops_system_log_sink.go b/backend/internal/service/ops_system_log_sink.go new file mode 100644 index 00000000..c50a30d5 --- /dev/null +++ b/backend/internal/service/ops_system_log_sink.go @@ -0,0 +1,335 @@ +package service + +import ( + "context" + "encoding/json" + "fmt" + "os" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/logredact" +) + +type OpsSystemLogSinkHealth struct { + QueueDepth int64 `json:"queue_depth"` + QueueCapacity int64 `json:"queue_capacity"` + DroppedCount uint64 `json:"dropped_count"` + WriteFailed uint64 `json:"write_failed_count"` + WrittenCount uint64 `json:"written_count"` + AvgWriteDelayMs uint64 `json:"avg_write_delay_ms"` + LastError string `json:"last_error"` +} + +type 
OpsSystemLogSink struct { + opsRepo OpsRepository + + queue chan *logger.LogEvent + + batchSize int + flushInterval time.Duration + + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + + droppedCount uint64 + writeFailed uint64 + writtenCount uint64 + totalDelayNs uint64 + + lastError atomic.Value +} + +func NewOpsSystemLogSink(opsRepo OpsRepository) *OpsSystemLogSink { + ctx, cancel := context.WithCancel(context.Background()) + s := &OpsSystemLogSink{ + opsRepo: opsRepo, + queue: make(chan *logger.LogEvent, 5000), + batchSize: 200, + flushInterval: time.Second, + ctx: ctx, + cancel: cancel, + } + s.lastError.Store("") + return s +} + +func (s *OpsSystemLogSink) Start() { + if s == nil || s.opsRepo == nil { + return + } + s.wg.Add(1) + go s.run() +} + +func (s *OpsSystemLogSink) Stop() { + if s == nil { + return + } + s.cancel() + s.wg.Wait() +} + +func (s *OpsSystemLogSink) WriteLogEvent(event *logger.LogEvent) { + if s == nil || event == nil || !s.shouldIndex(event) { + return + } + if s.ctx != nil { + select { + case <-s.ctx.Done(): + return + default: + } + } + + select { + case s.queue <- event: + default: + atomic.AddUint64(&s.droppedCount, 1) + } +} + +func (s *OpsSystemLogSink) shouldIndex(event *logger.LogEvent) bool { + level := strings.ToLower(strings.TrimSpace(event.Level)) + switch level { + case "warn", "warning", "error", "fatal", "panic", "dpanic": + return true + } + + component := strings.ToLower(strings.TrimSpace(event.Component)) + // zap 的 LoggerName 往往为空或不等于业务组件名;业务组件名通常以字段 component 透传。 + if event.Fields != nil { + if fc := strings.ToLower(strings.TrimSpace(asString(event.Fields["component"]))); fc != "" { + component = fc + } + } + if strings.Contains(component, "http.access") { + return true + } + if strings.Contains(component, "audit") { + return true + } + return false +} + +func (s *OpsSystemLogSink) run() { + defer s.wg.Done() + + ticker := time.NewTicker(s.flushInterval) + defer ticker.Stop() + + batch := 
make([]*logger.LogEvent, 0, s.batchSize) + flush := func(baseCtx context.Context) { + if len(batch) == 0 { + return + } + started := time.Now() + inserted, err := s.flushBatch(baseCtx, batch) + delay := time.Since(started) + if err != nil { + atomic.AddUint64(&s.writeFailed, uint64(len(batch))) + s.lastError.Store(err.Error()) + _, _ = fmt.Fprintf(os.Stderr, "time=%s level=WARN msg=\"ops system log sink flush failed\" err=%v batch=%d\n", + time.Now().Format(time.RFC3339Nano), err, len(batch), + ) + } else { + atomic.AddUint64(&s.writtenCount, uint64(inserted)) + atomic.AddUint64(&s.totalDelayNs, uint64(delay.Nanoseconds())) + s.lastError.Store("") + } + batch = batch[:0] + } + drainAndFlush := func() { + for { + select { + case item := <-s.queue: + if item == nil { + continue + } + batch = append(batch, item) + if len(batch) >= s.batchSize { + flush(context.Background()) + } + default: + flush(context.Background()) + return + } + } + } + + for { + select { + case <-s.ctx.Done(): + drainAndFlush() + return + case item := <-s.queue: + if item == nil { + continue + } + batch = append(batch, item) + if len(batch) >= s.batchSize { + flush(s.ctx) + } + case <-ticker.C: + flush(s.ctx) + } + } +} + +func (s *OpsSystemLogSink) flushBatch(baseCtx context.Context, batch []*logger.LogEvent) (int, error) { + inputs := make([]*OpsInsertSystemLogInput, 0, len(batch)) + for _, event := range batch { + if event == nil { + continue + } + createdAt := event.Time.UTC() + if createdAt.IsZero() { + createdAt = time.Now().UTC() + } + + fields := copyMap(event.Fields) + requestID := asString(fields["request_id"]) + clientRequestID := asString(fields["client_request_id"]) + platform := asString(fields["platform"]) + model := asString(fields["model"]) + component := strings.TrimSpace(event.Component) + if fieldComponent := asString(fields["component"]); fieldComponent != "" { + component = fieldComponent + } + if component == "" { + component = "app" + } + + userID := 
asInt64Ptr(fields["user_id"]) + accountID := asInt64Ptr(fields["account_id"]) + + // 统一脱敏后写入索引。 + message := logredact.RedactText(strings.TrimSpace(event.Message)) + redactedExtra := logredact.RedactMap(fields) + extraJSONBytes, _ := json.Marshal(redactedExtra) + extraJSON := string(extraJSONBytes) + if strings.TrimSpace(extraJSON) == "" { + extraJSON = "{}" + } + + inputs = append(inputs, &OpsInsertSystemLogInput{ + CreatedAt: createdAt, + Level: strings.ToLower(strings.TrimSpace(event.Level)), + Component: component, + Message: message, + RequestID: requestID, + ClientRequestID: clientRequestID, + UserID: userID, + AccountID: accountID, + Platform: platform, + Model: model, + ExtraJSON: extraJSON, + }) + } + + if len(inputs) == 0 { + return 0, nil + } + if baseCtx == nil || baseCtx.Err() != nil { + baseCtx = context.Background() + } + ctx, cancel := context.WithTimeout(baseCtx, 5*time.Second) + defer cancel() + inserted, err := s.opsRepo.BatchInsertSystemLogs(ctx, inputs) + if err != nil { + return 0, err + } + return int(inserted), nil +} + +func (s *OpsSystemLogSink) Health() OpsSystemLogSinkHealth { + if s == nil { + return OpsSystemLogSinkHealth{} + } + written := atomic.LoadUint64(&s.writtenCount) + totalDelay := atomic.LoadUint64(&s.totalDelayNs) + var avgDelay uint64 + if written > 0 { + avgDelay = (totalDelay / written) / uint64(time.Millisecond) + } + + lastErr, _ := s.lastError.Load().(string) + return OpsSystemLogSinkHealth{ + QueueDepth: int64(len(s.queue)), + QueueCapacity: int64(cap(s.queue)), + DroppedCount: atomic.LoadUint64(&s.droppedCount), + WriteFailed: atomic.LoadUint64(&s.writeFailed), + WrittenCount: written, + AvgWriteDelayMs: avgDelay, + LastError: strings.TrimSpace(lastErr), + } +} + +func copyMap(in map[string]any) map[string]any { + if len(in) == 0 { + return map[string]any{} + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func asString(v any) string { + switch t := v.(type) { + 
case string: + return strings.TrimSpace(t) + case fmt.Stringer: + return strings.TrimSpace(t.String()) + default: + return "" + } +} + +func asInt64Ptr(v any) *int64 { + switch t := v.(type) { + case int: + n := int64(t) + if n <= 0 { + return nil + } + return &n + case int64: + n := t + if n <= 0 { + return nil + } + return &n + case float64: + n := int64(t) + if n <= 0 { + return nil + } + return &n + case json.Number: + if n, err := t.Int64(); err == nil { + if n <= 0 { + return nil + } + return &n + } + case string: + raw := strings.TrimSpace(t) + if raw == "" { + return nil + } + if n, err := strconv.ParseInt(raw, 10, 64); err == nil { + if n <= 0 { + return nil + } + return &n + } + } + return nil +} diff --git a/backend/internal/service/ops_system_log_sink_test.go b/backend/internal/service/ops_system_log_sink_test.go new file mode 100644 index 00000000..12a2ec0c --- /dev/null +++ b/backend/internal/service/ops_system_log_sink_test.go @@ -0,0 +1,313 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +func TestOpsSystemLogSink_ShouldIndex(t *testing.T) { + sink := &OpsSystemLogSink{} + + cases := []struct { + name string + event *logger.LogEvent + want bool + }{ + { + name: "warn level", + event: &logger.LogEvent{Level: "warn", Component: "app"}, + want: true, + }, + { + name: "error level", + event: &logger.LogEvent{Level: "error", Component: "app"}, + want: true, + }, + { + name: "access component", + event: &logger.LogEvent{Level: "info", Component: "http.access"}, + want: true, + }, + { + name: "access component from fields (real zap path)", + event: &logger.LogEvent{ + Level: "info", + Component: "", + Fields: map[string]any{"component": "http.access"}, + }, + want: true, + }, + { + name: "audit component", + event: &logger.LogEvent{Level: "info", Component: "audit.log_config_change"}, + want: true, + }, + { + name: 
"audit component from fields (real zap path)", + event: &logger.LogEvent{ + Level: "info", + Component: "", + Fields: map[string]any{"component": "audit.log_config_change"}, + }, + want: true, + }, + { + name: "plain info", + event: &logger.LogEvent{Level: "info", Component: "app"}, + want: false, + }, + } + + for _, tc := range cases { + if got := sink.shouldIndex(tc.event); got != tc.want { + t.Fatalf("%s: shouldIndex()=%v, want %v", tc.name, got, tc.want) + } + } +} + +func TestOpsSystemLogSink_WriteLogEvent_ShouldDropWhenQueueFull(t *testing.T) { + sink := &OpsSystemLogSink{ + queue: make(chan *logger.LogEvent, 1), + } + + sink.WriteLogEvent(&logger.LogEvent{Level: "warn", Component: "app"}) + sink.WriteLogEvent(&logger.LogEvent{Level: "warn", Component: "app"}) + + if got := len(sink.queue); got != 1 { + t.Fatalf("queue len = %d, want 1", got) + } + if dropped := atomic.LoadUint64(&sink.droppedCount); dropped != 1 { + t.Fatalf("droppedCount = %d, want 1", dropped) + } +} + +func TestOpsSystemLogSink_Health(t *testing.T) { + sink := &OpsSystemLogSink{ + queue: make(chan *logger.LogEvent, 10), + } + sink.lastError.Store("db timeout") + atomic.StoreUint64(&sink.droppedCount, 3) + atomic.StoreUint64(&sink.writeFailed, 2) + atomic.StoreUint64(&sink.writtenCount, 5) + atomic.StoreUint64(&sink.totalDelayNs, uint64(5000000)) // 5ms total -> avg 1ms + sink.queue <- &logger.LogEvent{Level: "warn", Component: "app"} + sink.queue <- &logger.LogEvent{Level: "warn", Component: "app"} + + health := sink.Health() + if health.QueueDepth != 2 { + t.Fatalf("queue depth = %d, want 2", health.QueueDepth) + } + if health.QueueCapacity != 10 { + t.Fatalf("queue capacity = %d, want 10", health.QueueCapacity) + } + if health.DroppedCount != 3 { + t.Fatalf("dropped = %d, want 3", health.DroppedCount) + } + if health.WriteFailed != 2 { + t.Fatalf("write failed = %d, want 2", health.WriteFailed) + } + if health.WrittenCount != 5 { + t.Fatalf("written = %d, want 5", health.WrittenCount) + 
} + if health.AvgWriteDelayMs != 1 { + t.Fatalf("avg delay ms = %d, want 1", health.AvgWriteDelayMs) + } + if health.LastError != "db timeout" { + t.Fatalf("last error = %q, want db timeout", health.LastError) + } +} + +func TestOpsSystemLogSink_StartStopAndFlushSuccess(t *testing.T) { + done := make(chan struct{}, 1) + var captured []*OpsInsertSystemLogInput + repo := &opsRepoMock{ + BatchInsertSystemLogsFn: func(_ context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) { + captured = append(captured, inputs...) + select { + case done <- struct{}{}: + default: + } + return int64(len(inputs)), nil + }, + } + + sink := NewOpsSystemLogSink(repo) + sink.batchSize = 1 + sink.flushInterval = 10 * time.Millisecond + sink.Start() + defer sink.Stop() + + sink.WriteLogEvent(&logger.LogEvent{ + Time: time.Now().UTC(), + Level: "warn", + Component: "http.access", + Message: `authorization="Bearer sk-test-123"`, + Fields: map[string]any{ + "component": "http.access", + "request_id": "req-1", + "client_request_id": "creq-1", + "user_id": "12", + "account_id": json.Number("34"), + "platform": "openai", + "model": "gpt-5", + }, + }) + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting for sink flush") + } + + if len(captured) != 1 { + t.Fatalf("captured len = %d, want 1", len(captured)) + } + item := captured[0] + if item.RequestID != "req-1" || item.ClientRequestID != "creq-1" { + t.Fatalf("unexpected request ids: %+v", item) + } + if item.UserID == nil || *item.UserID != 12 { + t.Fatalf("unexpected user_id: %+v", item.UserID) + } + if item.AccountID == nil || *item.AccountID != 34 { + t.Fatalf("unexpected account_id: %+v", item.AccountID) + } + if strings.TrimSpace(item.Message) == "" { + t.Fatalf("message should not be empty") + } + health := sink.Health() + if health.WrittenCount == 0 { + t.Fatalf("written_count should be >0") + } +} + +func TestOpsSystemLogSink_FlushFailureUpdatesHealth(t *testing.T) { + repo := 
&opsRepoMock{ + BatchInsertSystemLogsFn: func(_ context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) { + return 0, errors.New("db unavailable") + }, + } + sink := NewOpsSystemLogSink(repo) + sink.batchSize = 1 + sink.flushInterval = 10 * time.Millisecond + sink.Start() + defer sink.Stop() + + sink.WriteLogEvent(&logger.LogEvent{ + Time: time.Now().UTC(), + Level: "warn", + Component: "app", + Message: "boom", + Fields: map[string]any{}, + }) + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + health := sink.Health() + if health.WriteFailed > 0 { + if !strings.Contains(health.LastError, "db unavailable") { + t.Fatalf("unexpected last error: %s", health.LastError) + } + return + } + time.Sleep(20 * time.Millisecond) + } + t.Fatalf("write_failed_count not updated") +} + +func TestOpsSystemLogSink_StopFlushUsesActiveContextAndDrainsQueue(t *testing.T) { + var inserted int64 + var canceledCtxCalls int64 + repo := &opsRepoMock{ + BatchInsertSystemLogsFn: func(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) { + if err := ctx.Err(); err != nil { + atomic.AddInt64(&canceledCtxCalls, 1) + return 0, err + } + atomic.AddInt64(&inserted, int64(len(inputs))) + return int64(len(inputs)), nil + }, + } + + sink := NewOpsSystemLogSink(repo) + sink.batchSize = 200 + sink.flushInterval = time.Hour + sink.Start() + + sink.WriteLogEvent(&logger.LogEvent{ + Time: time.Now().UTC(), + Level: "warn", + Component: "app", + Message: "pending-on-shutdown", + Fields: map[string]any{"component": "http.access"}, + }) + + sink.Stop() + + if got := atomic.LoadInt64(&inserted); got != 1 { + t.Fatalf("inserted = %d, want 1", got) + } + if got := atomic.LoadInt64(&canceledCtxCalls); got != 0 { + t.Fatalf("canceled ctx calls = %d, want 0", got) + } + health := sink.Health() + if health.WrittenCount != 1 { + t.Fatalf("written_count = %d, want 1", health.WrittenCount) + } +} + +type stringerValue string + +func (s stringerValue) 
String() string { return string(s) } + +func TestOpsSystemLogSink_HelperFunctions(t *testing.T) { + src := map[string]any{"a": 1} + cloned := copyMap(src) + src["a"] = 2 + v, ok := cloned["a"].(int) + if !ok || v != 1 { + t.Fatalf("copyMap should create copy") + } + if got := asString(stringerValue(" hello ")); got != "hello" { + t.Fatalf("asString stringer = %q", got) + } + if got := asString(fmt.Errorf("x")); got != "" { + t.Fatalf("asString error should be empty, got %q", got) + } + if got := asString(123); got != "" { + t.Fatalf("asString non-string should be empty, got %q", got) + } + + cases := []struct { + in any + want int64 + ok bool + }{ + {in: 5, want: 5, ok: true}, + {in: int64(6), want: 6, ok: true}, + {in: float64(7), want: 7, ok: true}, + {in: json.Number("8"), want: 8, ok: true}, + {in: "9", want: 9, ok: true}, + {in: "0", ok: false}, + {in: -1, ok: false}, + {in: "abc", ok: false}, + } + for _, tc := range cases { + got := asInt64Ptr(tc.in) + if tc.ok { + if got == nil || *got != tc.want { + t.Fatalf("asInt64Ptr(%v) = %+v, want %d", tc.in, got, tc.want) + } + } else if got != nil { + t.Fatalf("asInt64Ptr(%v) should be nil, got %d", tc.in, *got) + } + } +} diff --git a/backend/internal/service/ops_trends.go b/backend/internal/service/ops_trends.go index ec55c6ce..22db72ef 100644 --- a/backend/internal/service/ops_trends.go +++ b/backend/internal/service/ops_trends.go @@ -22,5 +22,13 @@ func (s *OpsService) GetThroughputTrend(ctx context.Context, filter *OpsDashboar if filter.StartTime.After(filter.EndTime) { return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time") } - return s.opsRepo.GetThroughputTrend(ctx, filter, bucketSeconds) + + filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode) + + result, err := s.opsRepo.GetThroughputTrend(ctx, filter, bucketSeconds) + if err != nil && shouldFallbackOpsPreagg(filter, err) { + rawFilter := cloneOpsFilterWithMode(filter, OpsQueryModeRaw) + return 
s.opsRepo.GetThroughputTrend(ctx, rawFilter, bucketSeconds) + } + return result, err } diff --git a/backend/internal/service/ops_upstream_context.go b/backend/internal/service/ops_upstream_context.go index 3514df79..21e09c43 100644 --- a/backend/internal/service/ops_upstream_context.go +++ b/backend/internal/service/ops_upstream_context.go @@ -21,11 +21,38 @@ const ( // This value is sanitized+trimmed before being persisted. OpsUpstreamRequestBodyKey = "ops_upstream_request_body" + // Optional stage latencies (milliseconds) for troubleshooting and alerting. + OpsAuthLatencyMsKey = "ops_auth_latency_ms" + OpsRoutingLatencyMsKey = "ops_routing_latency_ms" + OpsUpstreamLatencyMsKey = "ops_upstream_latency_ms" + OpsResponseLatencyMsKey = "ops_response_latency_ms" + OpsTimeToFirstTokenMsKey = "ops_time_to_first_token_ms" + // OpenAI WS 关键观测字段 + OpsOpenAIWSQueueWaitMsKey = "ops_openai_ws_queue_wait_ms" + OpsOpenAIWSConnPickMsKey = "ops_openai_ws_conn_pick_ms" + OpsOpenAIWSConnReusedKey = "ops_openai_ws_conn_reused" + OpsOpenAIWSConnIDKey = "ops_openai_ws_conn_id" + // OpsSkipPassthroughKey 由 applyErrorPassthroughRule 在命中 skip_monitoring=true 的规则时设置。 // ops_error_logger 中间件检查此 key,为 true 时跳过错误记录。 OpsSkipPassthroughKey = "ops_skip_passthrough" ) +func setOpsUpstreamRequestBody(c *gin.Context, body []byte) { + if c == nil || len(body) == 0 { + return + } + // 热路径避免 string(body) 额外分配,按需在落库前再转换。 + c.Set(OpsUpstreamRequestBodyKey, body) +} + +func SetOpsLatencyMs(c *gin.Context, key string, value int64) { + if c == nil || strings.TrimSpace(key) == "" || value < 0 { + return + } + c.Set(key, value) +} + func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) { if c == nil { return @@ -46,6 +73,10 @@ func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage type OpsUpstreamErrorEvent struct { AtUnixMs int64 `json:"at_unix_ms,omitempty"` + // Passthrough 表示本次请求是否命中“原样透传(仅替换认证)”分支。 + // 该字段用于排障与灰度评估;存入 
JSON,不涉及 DB schema 变更。 + Passthrough bool `json:"passthrough,omitempty"` + // Context Platform string `json:"platform,omitempty"` AccountID int64 `json:"account_id,omitempty"` @@ -91,8 +122,11 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) { // stored it on the context, attach it so ops can retry this specific attempt. if ev.UpstreamRequestBody == "" { if v, ok := c.Get(OpsUpstreamRequestBodyKey); ok { - if s, ok := v.(string); ok { - ev.UpstreamRequestBody = strings.TrimSpace(s) + switch raw := v.(type) { + case string: + ev.UpstreamRequestBody = strings.TrimSpace(raw) + case []byte: + ev.UpstreamRequestBody = strings.TrimSpace(string(raw)) } } } diff --git a/backend/internal/service/ops_upstream_context_test.go b/backend/internal/service/ops_upstream_context_test.go new file mode 100644 index 00000000..50ceaa0e --- /dev/null +++ b/backend/internal/service/ops_upstream_context_test.go @@ -0,0 +1,47 @@ +package service + +import ( + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestAppendOpsUpstreamError_UsesRequestBodyBytesFromContext(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + setOpsUpstreamRequestBody(c, []byte(`{"model":"gpt-5"}`)) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Kind: "http_error", + Message: "upstream failed", + }) + + v, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + events, ok := v.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.Len(t, events, 1) + require.Equal(t, `{"model":"gpt-5"}`, events[0].UpstreamRequestBody) +} + +func TestAppendOpsUpstreamError_UsesRequestBodyStringFromContext(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + c.Set(OpsUpstreamRequestBodyKey, `{"model":"gpt-4"}`) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Kind: "request_error", + Message: "dial 
timeout", + }) + + v, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + events, ok := v.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.Len(t, events, 1) + require.Equal(t, `{"model":"gpt-4"}`, events[0].UpstreamRequestBody) +} diff --git a/backend/internal/service/parse_integral_number_unit.go b/backend/internal/service/parse_integral_number_unit.go new file mode 100644 index 00000000..c9c617b1 --- /dev/null +++ b/backend/internal/service/parse_integral_number_unit.go @@ -0,0 +1,51 @@ +//go:build unit + +package service + +import ( + "encoding/json" + "math" +) + +// parseIntegralNumber 将 JSON 解码后的数字安全转换为 int。 +// 仅接受“整数值”的输入,小数/NaN/Inf/越界值都会返回 false。 +// +// 说明: +// - 该函数当前仅用于 unit 测试中的 map-based 解析逻辑验证,因此放在 unit build tag 下, +// 避免在默认构建中触发 unused lint。 +func parseIntegralNumber(raw any) (int, bool) { + switch v := raw.(type) { + case float64: + if math.IsNaN(v) || math.IsInf(v, 0) || v != math.Trunc(v) { + return 0, false + } + if v > float64(math.MaxInt) || v < float64(math.MinInt) { + return 0, false + } + return int(v), true + case int: + return v, true + case int8: + return int(v), true + case int16: + return int(v), true + case int32: + return int(v), true + case int64: + if v > int64(math.MaxInt) || v < int64(math.MinInt) { + return 0, false + } + return int(v), true + case json.Number: + i64, err := v.Int64() + if err != nil { + return 0, false + } + if i64 > int64(math.MaxInt) || i64 < int64(math.MinInt) { + return 0, false + } + return int(i64), true + default: + return 0, false + } +} diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index d8db0d67..41e8b5eb 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -6,7 +6,6 @@ import ( "encoding/hex" "encoding/json" "fmt" - "log" "os" "path/filepath" "regexp" @@ -15,8 +14,10 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + 
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" + "go.uber.org/zap" ) var ( @@ -27,14 +28,15 @@ var ( // LiteLLMModelPricing LiteLLM价格数据结构 // 只保留我们需要的字段,使用指针来处理可能缺失的值 type LiteLLMModelPricing struct { - InputCostPerToken float64 `json:"input_cost_per_token"` - OutputCostPerToken float64 `json:"output_cost_per_token"` - CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"` - CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"` - LiteLLMProvider string `json:"litellm_provider"` - Mode string `json:"mode"` - SupportsPromptCaching bool `json:"supports_prompt_caching"` - OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格 + InputCostPerToken float64 `json:"input_cost_per_token"` + OutputCostPerToken float64 `json:"output_cost_per_token"` + CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"` + CacheCreationInputTokenCostAbove1hr float64 `json:"cache_creation_input_token_cost_above_1hr"` + CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"` + LiteLLMProvider string `json:"litellm_provider"` + Mode string `json:"mode"` + SupportsPromptCaching bool `json:"supports_prompt_caching"` + OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格 } // PricingRemoteClient 远程价格数据获取接口 @@ -45,14 +47,15 @@ type PricingRemoteClient interface { // LiteLLMRawEntry 用于解析原始JSON数据 type LiteLLMRawEntry struct { - InputCostPerToken *float64 `json:"input_cost_per_token"` - OutputCostPerToken *float64 `json:"output_cost_per_token"` - CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"` - CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"` - LiteLLMProvider string `json:"litellm_provider"` - Mode string `json:"mode"` - SupportsPromptCaching bool `json:"supports_prompt_caching"` - OutputCostPerImage *float64 
`json:"output_cost_per_image"` + InputCostPerToken *float64 `json:"input_cost_per_token"` + OutputCostPerToken *float64 `json:"output_cost_per_token"` + CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"` + CacheCreationInputTokenCostAbove1hr *float64 `json:"cache_creation_input_token_cost_above_1hr"` + CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"` + LiteLLMProvider string `json:"litellm_provider"` + Mode string `json:"mode"` + SupportsPromptCaching bool `json:"supports_prompt_caching"` + OutputCostPerImage *float64 `json:"output_cost_per_image"` } // PricingService 动态价格服务 @@ -84,12 +87,12 @@ func NewPricingService(cfg *config.Config, remoteClient PricingRemoteClient) *Pr func (s *PricingService) Initialize() error { // 确保数据目录存在 if err := os.MkdirAll(s.cfg.Pricing.DataDir, 0755); err != nil { - log.Printf("[Pricing] Failed to create data directory: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to create data directory: %v", err) } // 首次加载价格数据 if err := s.checkAndUpdatePricing(); err != nil { - log.Printf("[Pricing] Initial load failed, using fallback: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Initial load failed, using fallback: %v", err) if err := s.useFallbackPricing(); err != nil { return fmt.Errorf("failed to load pricing data: %w", err) } @@ -98,7 +101,7 @@ func (s *PricingService) Initialize() error { // 启动定时更新 s.startUpdateScheduler() - log.Printf("[Pricing] Service initialized with %d models", len(s.pricingData)) + logger.LegacyPrintf("service.pricing", "[Pricing] Service initialized with %d models", len(s.pricingData)) return nil } @@ -106,7 +109,7 @@ func (s *PricingService) Initialize() error { func (s *PricingService) Stop() { close(s.stopCh) s.wg.Wait() - log.Println("[Pricing] Service stopped") + logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Service stopped") } // startUpdateScheduler 启动定时更新调度器 @@ -127,7 +130,7 @@ func (s *PricingService) 
startUpdateScheduler() { select { case <-ticker.C: if err := s.syncWithRemote(); err != nil { - log.Printf("[Pricing] Sync failed: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Sync failed: %v", err) } case <-s.stopCh: return @@ -135,7 +138,7 @@ func (s *PricingService) startUpdateScheduler() { } }() - log.Printf("[Pricing] Update scheduler started (check every %v)", hashInterval) + logger.LegacyPrintf("service.pricing", "[Pricing] Update scheduler started (check every %v)", hashInterval) } // checkAndUpdatePricing 检查并更新价格数据 @@ -144,7 +147,7 @@ func (s *PricingService) checkAndUpdatePricing() error { // 检查本地文件是否存在 if _, err := os.Stat(pricingFile); os.IsNotExist(err) { - log.Println("[Pricing] Local pricing file not found, downloading...") + logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Local pricing file not found, downloading...") return s.downloadPricingData() } @@ -158,9 +161,9 @@ func (s *PricingService) checkAndUpdatePricing() error { maxAge := time.Duration(s.cfg.Pricing.UpdateIntervalHours) * time.Hour if fileAge > maxAge { - log.Printf("[Pricing] Local file is %v old, updating...", fileAge.Round(time.Hour)) + logger.LegacyPrintf("service.pricing", "[Pricing] Local file is %v old, updating...", fileAge.Round(time.Hour)) if err := s.downloadPricingData(); err != nil { - log.Printf("[Pricing] Download failed, using existing file: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Download failed, using existing file: %v", err) } } @@ -175,7 +178,7 @@ func (s *PricingService) syncWithRemote() error { // 计算本地文件哈希 localHash, err := s.computeFileHash(pricingFile) if err != nil { - log.Printf("[Pricing] Failed to compute local hash: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to compute local hash: %v", err) return s.downloadPricingData() } @@ -183,15 +186,15 @@ func (s *PricingService) syncWithRemote() error { if s.cfg.Pricing.HashURL != "" { remoteHash, err := s.fetchRemoteHash() if err != nil { - 
log.Printf("[Pricing] Failed to fetch remote hash: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to fetch remote hash: %v", err) return nil // 哈希获取失败不影响正常使用 } if remoteHash != localHash { - log.Println("[Pricing] Remote hash differs, downloading new version...") + logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Remote hash differs, downloading new version...") return s.downloadPricingData() } - log.Println("[Pricing] Hash check passed, no update needed") + logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Hash check passed, no update needed") return nil } @@ -205,7 +208,7 @@ func (s *PricingService) syncWithRemote() error { maxAge := time.Duration(s.cfg.Pricing.UpdateIntervalHours) * time.Hour if fileAge > maxAge { - log.Printf("[Pricing] File is %v old, downloading...", fileAge.Round(time.Hour)) + logger.LegacyPrintf("service.pricing", "[Pricing] File is %v old, downloading...", fileAge.Round(time.Hour)) return s.downloadPricingData() } @@ -218,7 +221,7 @@ func (s *PricingService) downloadPricingData() error { if err != nil { return err } - log.Printf("[Pricing] Downloading from %s", remoteURL) + logger.LegacyPrintf("service.pricing", "[Pricing] Downloading from %s", remoteURL) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -252,7 +255,7 @@ func (s *PricingService) downloadPricingData() error { // 保存到本地文件 pricingFile := s.getPricingFilePath() if err := os.WriteFile(pricingFile, body, 0644); err != nil { - log.Printf("[Pricing] Failed to save file: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to save file: %v", err) } // 保存哈希 @@ -260,7 +263,7 @@ func (s *PricingService) downloadPricingData() error { hashStr := hex.EncodeToString(hash[:]) hashFile := s.getHashFilePath() if err := os.WriteFile(hashFile, []byte(hashStr+"\n"), 0644); err != nil { - log.Printf("[Pricing] Failed to save hash: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to save 
hash: %v", err) } // 更新内存数据 @@ -270,7 +273,7 @@ func (s *PricingService) downloadPricingData() error { s.localHash = hashStr s.mu.Unlock() - log.Printf("[Pricing] Downloaded %d models successfully", len(data)) + logger.LegacyPrintf("service.pricing", "[Pricing] Downloaded %d models successfully", len(data)) return nil } @@ -318,6 +321,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel if entry.CacheCreationInputTokenCost != nil { pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost } + if entry.CacheCreationInputTokenCostAbove1hr != nil { + pricing.CacheCreationInputTokenCostAbove1hr = *entry.CacheCreationInputTokenCostAbove1hr + } if entry.CacheReadInputTokenCost != nil { pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost } @@ -329,7 +335,7 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel } if skipped > 0 { - log.Printf("[Pricing] Skipped %d invalid entries", skipped) + logger.LegacyPrintf("service.pricing", "[Pricing] Skipped %d invalid entries", skipped) } if len(result) == 0 { @@ -368,7 +374,7 @@ func (s *PricingService) loadPricingData(filePath string) error { } s.mu.Unlock() - log.Printf("[Pricing] Loaded %d models from %s", len(pricingData), filePath) + logger.LegacyPrintf("service.pricing", "[Pricing] Loaded %d models from %s", len(pricingData), filePath) return nil } @@ -380,7 +386,7 @@ func (s *PricingService) useFallbackPricing() error { return fmt.Errorf("fallback file not found: %s", fallbackFile) } - log.Printf("[Pricing] Using fallback file: %s", fallbackFile) + logger.LegacyPrintf("service.pricing", "[Pricing] Using fallback file: %s", fallbackFile) // 复制到数据目录 data, err := os.ReadFile(fallbackFile) @@ -390,7 +396,7 @@ func (s *PricingService) useFallbackPricing() error { pricingFile := s.getPricingFilePath() if err := os.WriteFile(pricingFile, data, 0644); err != nil { - log.Printf("[Pricing] Failed to copy fallback: %v", err) + 
logger.LegacyPrintf("service.pricing", "[Pricing] Failed to copy fallback: %v", err) } return s.loadPricingData(fallbackFile) @@ -639,7 +645,7 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing { for key, pricing := range s.pricingData { keyLower := strings.ToLower(key) if strings.Contains(keyLower, pattern) { - log.Printf("[Pricing] Fuzzy matched %s -> %s", model, key) + logger.LegacyPrintf("service.pricing", "[Pricing] Fuzzy matched %s -> %s", model, key) return pricing } } @@ -650,24 +656,36 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing { // matchOpenAIModel OpenAI 模型回退匹配策略 // 回退顺序: -// 1. gpt-5.2-codex -> gpt-5.2(去掉后缀如 -codex, -mini, -max 等) -// 2. gpt-5.2-20251222 -> gpt-5.2(去掉日期版本号) -// 3. gpt-5.3-codex -> gpt-5.2-codex -// 4. 最终回退到 DefaultTestModel (gpt-5.1-codex) +// 1. gpt-5.3-codex-spark* -> gpt-5.1-codex(按业务要求固定计费) +// 2. gpt-5.2-codex -> gpt-5.2(去掉后缀如 -codex, -mini, -max 等) +// 3. gpt-5.2-20251222 -> gpt-5.2(去掉日期版本号) +// 4. gpt-5.3-codex -> gpt-5.2-codex +// 5. 最终回退到 DefaultTestModel (gpt-5.1-codex) func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { + if strings.HasPrefix(model, "gpt-5.3-codex-spark") { + if pricing, ok := s.pricingData["gpt-5.1-codex"]; ok { + logger.LegacyPrintf("service.pricing", "[Pricing][SparkBilling] %s -> %s billing", model, "gpt-5.1-codex") + logger.With(zap.String("component", "service.pricing")). + Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.1-codex")) + return pricing + } + } + // 尝试的回退变体 variants := s.generateOpenAIModelVariants(model, openAIModelDatePattern) for _, variant := range variants { if pricing, ok := s.pricingData[variant]; ok { - log.Printf("[Pricing] OpenAI fallback matched %s -> %s", model, variant) + logger.With(zap.String("component", "service.pricing")). 
+ Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, variant)) return pricing } } if strings.HasPrefix(model, "gpt-5.3-codex") { if pricing, ok := s.pricingData["gpt-5.2-codex"]; ok { - log.Printf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.2-codex") + logger.With(zap.String("component", "service.pricing")). + Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.2-codex")) return pricing } } @@ -675,7 +693,7 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { // 最终回退到 DefaultTestModel defaultModel := strings.ToLower(openai.DefaultTestModel) if pricing, ok := s.pricingData[defaultModel]; ok { - log.Printf("[Pricing] OpenAI fallback to default model %s -> %s", model, defaultModel) + logger.LegacyPrintf("service.pricing", "[Pricing] OpenAI fallback to default model %s -> %s", model, defaultModel) return pricing } diff --git a/backend/internal/service/pricing_service_test.go b/backend/internal/service/pricing_service_test.go new file mode 100644 index 00000000..127ff342 --- /dev/null +++ b/backend/internal/service/pricing_service_test.go @@ -0,0 +1,53 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetModelPricing_Gpt53CodexSparkUsesGpt51CodexPricing(t *testing.T) { + sparkPricing := &LiteLLMModelPricing{InputCostPerToken: 1} + gpt53Pricing := &LiteLLMModelPricing{InputCostPerToken: 9} + + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-5.1-codex": sparkPricing, + "gpt-5.3": gpt53Pricing, + }, + } + + got := svc.GetModelPricing("gpt-5.3-codex-spark") + require.Same(t, sparkPricing, got) +} + +func TestGetModelPricing_Gpt53CodexFallbackStillUsesGpt52Codex(t *testing.T) { + gpt52CodexPricing := &LiteLLMModelPricing{InputCostPerToken: 2} + + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-5.2-codex": gpt52CodexPricing, + }, + } + + got := 
svc.GetModelPricing("gpt-5.3-codex") + require.Same(t, gpt52CodexPricing, got) +} + +func TestGetModelPricing_OpenAIFallbackMatchedLoggedAsInfo(t *testing.T) { + logSink, restore := captureStructuredLog(t) + defer restore() + + gpt52CodexPricing := &LiteLLMModelPricing{InputCostPerToken: 2} + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-5.2-codex": gpt52CodexPricing, + }, + } + + got := svc.GetModelPricing("gpt-5.3-codex") + require.Same(t, gpt52CodexPricing, got) + + require.True(t, logSink.ContainsMessageAtLevel("[Pricing] OpenAI fallback matched gpt-5.3-codex -> gpt-5.2-codex", "info")) + require.False(t, logSink.ContainsMessageAtLevel("[Pricing] OpenAI fallback matched gpt-5.3-codex -> gpt-5.2-codex", "warn")) +} diff --git a/backend/internal/service/proxy.go b/backend/internal/service/proxy.go index 7eb7728f..fc449091 100644 --- a/backend/internal/service/proxy.go +++ b/backend/internal/service/proxy.go @@ -40,6 +40,11 @@ type ProxyWithAccountCount struct { CountryCode string Region string City string + QualityStatus string + QualityScore *int + QualityGrade string + QualitySummary string + QualityChecked *int64 } type ProxyAccountSummary struct { diff --git a/backend/internal/service/proxy_latency_cache.go b/backend/internal/service/proxy_latency_cache.go index 4a1cc77b..f54bff88 100644 --- a/backend/internal/service/proxy_latency_cache.go +++ b/backend/internal/service/proxy_latency_cache.go @@ -6,15 +6,21 @@ import ( ) type ProxyLatencyInfo struct { - Success bool `json:"success"` - LatencyMs *int64 `json:"latency_ms,omitempty"` - Message string `json:"message,omitempty"` - IPAddress string `json:"ip_address,omitempty"` - Country string `json:"country,omitempty"` - CountryCode string `json:"country_code,omitempty"` - Region string `json:"region,omitempty"` - City string `json:"city,omitempty"` - UpdatedAt time.Time `json:"updated_at"` + Success bool `json:"success"` + LatencyMs *int64 `json:"latency_ms,omitempty"` + Message 
string `json:"message,omitempty"` + IPAddress string `json:"ip_address,omitempty"` + Country string `json:"country,omitempty"` + CountryCode string `json:"country_code,omitempty"` + Region string `json:"region,omitempty"` + City string `json:"city,omitempty"` + QualityStatus string `json:"quality_status,omitempty"` + QualityScore *int `json:"quality_score,omitempty"` + QualityGrade string `json:"quality_grade,omitempty"` + QualitySummary string `json:"quality_summary,omitempty"` + QualityCheckedAt *int64 `json:"quality_checked_at,omitempty"` + QualityCFRay string `json:"quality_cf_ray,omitempty"` + UpdatedAt time.Time `json:"updated_at"` } type ProxyLatencyCache interface { diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index dab99b72..9f16fb2b 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -11,6 +11,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) // RateLimitService 处理限流和过载状态管理 @@ -33,6 +34,10 @@ type geminiUsageCacheEntry struct { totals GeminiUsageTotals } +type geminiUsageTotalsBatchProvider interface { + GetGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, startTime, endTime time.Time) (map[int64]GeminiUsageTotals, error) +} + const geminiPrecheckCacheTTL = time.Minute // NewRateLimitService 创建RateLimitService实例 @@ -141,13 +146,29 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc } else { slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform) } + // 3. 
临时不可调度,替代 SetError(保持 status=active 让刷新服务能拾取) + msg := "Authentication failed (401): invalid or expired credentials" + if upstreamMsg != "" { + msg = "OAuth 401: " + upstreamMsg + } + cooldownMinutes := s.cfg.RateLimit.OAuth401CooldownMinutes + if cooldownMinutes <= 0 { + cooldownMinutes = 10 + } + until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute) + if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, msg); err != nil { + slog.Warn("oauth_401_set_temp_unschedulable_failed", "account_id", account.ID, "error", err) + } + shouldDisable = true + } else { + // 非 OAuth 账号(APIKey):保持原有 SetError 行为 + msg := "Authentication failed (401): invalid or expired credentials" + if upstreamMsg != "" { + msg = "Authentication failed (401): " + upstreamMsg + } + s.handleAuthError(ctx, account, msg) + shouldDisable = true } - msg := "Authentication failed (401): invalid or expired credentials" - if upstreamMsg != "" { - msg = "Authentication failed (401): " + upstreamMsg - } - s.handleAuthError(ctx, account, msg) - shouldDisable = true case 402: // 支付要求:余额不足或计费问题,停止调度 msg := "Payment required (402): insufficient balance or billing issue" @@ -162,6 +183,17 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc if upstreamMsg != "" { msg = "Access forbidden (403): " + upstreamMsg } + logger.LegacyPrintf( + "service.ratelimit", + "[HandleUpstreamErrorRaw] account_id=%d platform=%s type=%s status=403 request_id=%s cf_ray=%s upstream_msg=%s raw_body=%s", + account.ID, + account.Platform, + account.Type, + strings.TrimSpace(headers.Get("x-request-id")), + strings.TrimSpace(headers.Get("cf-ray")), + upstreamMsg, + truncateForLog(responseBody, 1024), + ) s.handleAuthError(ctx, account, msg) shouldDisable = true case 429: @@ -225,7 +257,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, start := geminiDailyWindowStart(now) totals, ok := s.getGeminiUsageTotals(account.ID, start, now) if !ok { - 
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil) + stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil, nil) if err != nil { return true, err } @@ -272,7 +304,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, if limit > 0 { start := now.Truncate(time.Minute) - stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil) + stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil, nil) if err != nil { return true, err } @@ -302,6 +334,218 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, return true, nil } +// PreCheckUsageBatch performs quota precheck for multiple accounts in one request. +// Returned map value=false means the account should be skipped. +func (s *RateLimitService) PreCheckUsageBatch(ctx context.Context, accounts []*Account, requestedModel string) (map[int64]bool, error) { + result := make(map[int64]bool, len(accounts)) + for _, account := range accounts { + if account == nil { + continue + } + result[account.ID] = true + } + + if len(accounts) == 0 || requestedModel == "" { + return result, nil + } + if s.usageRepo == nil || s.geminiQuotaService == nil { + return result, nil + } + + modelClass := geminiModelClassFromName(requestedModel) + now := time.Now() + dailyStart := geminiDailyWindowStart(now) + minuteStart := now.Truncate(time.Minute) + + type quotaAccount struct { + account *Account + quota GeminiQuota + } + quotaAccounts := make([]quotaAccount, 0, len(accounts)) + for _, account := range accounts { + if account == nil || account.Platform != PlatformGemini { + continue + } + quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account) + if !ok { + continue + } + quotaAccounts = append(quotaAccounts, quotaAccount{ + account: account, + quota: quota, + }) + } + if len(quotaAccounts) == 0 { + return 
result, nil + } + + // 1) Daily precheck (cached + batch DB fallback) + dailyTotalsByID := make(map[int64]GeminiUsageTotals, len(quotaAccounts)) + dailyMissIDs := make([]int64, 0, len(quotaAccounts)) + for _, item := range quotaAccounts { + limit := geminiDailyLimit(item.quota, modelClass) + if limit <= 0 { + continue + } + accountID := item.account.ID + if totals, ok := s.getGeminiUsageTotals(accountID, dailyStart, now); ok { + dailyTotalsByID[accountID] = totals + continue + } + dailyMissIDs = append(dailyMissIDs, accountID) + } + if len(dailyMissIDs) > 0 { + totalsBatch, err := s.getGeminiUsageTotalsBatch(ctx, dailyMissIDs, dailyStart, now) + if err != nil { + return result, err + } + for _, accountID := range dailyMissIDs { + totals := totalsBatch[accountID] + dailyTotalsByID[accountID] = totals + s.setGeminiUsageTotals(accountID, dailyStart, now, totals) + } + } + for _, item := range quotaAccounts { + limit := geminiDailyLimit(item.quota, modelClass) + if limit <= 0 { + continue + } + accountID := item.account.ID + used := geminiUsedRequests(item.quota, modelClass, dailyTotalsByID[accountID], true) + if used >= limit { + resetAt := geminiDailyResetTime(now) + slog.Info("gemini_precheck_daily_quota_reached_batch", "account_id", accountID, "used", used, "limit", limit, "reset_at", resetAt) + result[accountID] = false + } + } + + // 2) Minute precheck (batch DB) + minuteIDs := make([]int64, 0, len(quotaAccounts)) + for _, item := range quotaAccounts { + accountID := item.account.ID + if !result[accountID] { + continue + } + if geminiMinuteLimit(item.quota, modelClass) <= 0 { + continue + } + minuteIDs = append(minuteIDs, accountID) + } + if len(minuteIDs) == 0 { + return result, nil + } + + minuteTotalsByID, err := s.getGeminiUsageTotalsBatch(ctx, minuteIDs, minuteStart, now) + if err != nil { + return result, err + } + for _, item := range quotaAccounts { + accountID := item.account.ID + if !result[accountID] { + continue + } + + limit := 
geminiMinuteLimit(item.quota, modelClass) + if limit <= 0 { + continue + } + + used := geminiUsedRequests(item.quota, modelClass, minuteTotalsByID[accountID], false) + if used >= limit { + resetAt := minuteStart.Add(time.Minute) + slog.Info("gemini_precheck_minute_quota_reached_batch", "account_id", accountID, "used", used, "limit", limit, "reset_at", resetAt) + result[accountID] = false + } + } + + return result, nil +} + +func (s *RateLimitService) getGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, start, end time.Time) (map[int64]GeminiUsageTotals, error) { + result := make(map[int64]GeminiUsageTotals, len(accountIDs)) + if len(accountIDs) == 0 { + return result, nil + } + + ids := make([]int64, 0, len(accountIDs)) + seen := make(map[int64]struct{}, len(accountIDs)) + for _, accountID := range accountIDs { + if accountID <= 0 { + continue + } + if _, ok := seen[accountID]; ok { + continue + } + seen[accountID] = struct{}{} + ids = append(ids, accountID) + } + if len(ids) == 0 { + return result, nil + } + + if batchReader, ok := s.usageRepo.(geminiUsageTotalsBatchProvider); ok { + stats, err := batchReader.GetGeminiUsageTotalsBatch(ctx, ids, start, end) + if err != nil { + return nil, err + } + for _, accountID := range ids { + result[accountID] = stats[accountID] + } + return result, nil + } + + for _, accountID := range ids { + stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, end, 0, 0, accountID, 0, nil, nil, nil) + if err != nil { + return nil, err + } + result[accountID] = geminiAggregateUsage(stats) + } + return result, nil +} + +func geminiDailyLimit(quota GeminiQuota, modelClass geminiModelClass) int64 { + if quota.SharedRPD > 0 { + return quota.SharedRPD + } + switch modelClass { + case geminiModelFlash: + return quota.FlashRPD + default: + return quota.ProRPD + } +} + +func geminiMinuteLimit(quota GeminiQuota, modelClass geminiModelClass) int64 { + if quota.SharedRPM > 0 { + return quota.SharedRPM + } + switch modelClass { 
+ case geminiModelFlash: + return quota.FlashRPM + default: + return quota.ProRPM + } +} + +func geminiUsedRequests(quota GeminiQuota, modelClass geminiModelClass, totals GeminiUsageTotals, daily bool) int64 { + if daily { + if quota.SharedRPD > 0 { + return totals.ProRequests + totals.FlashRequests + } + } else { + if quota.SharedRPM > 0 { + return totals.ProRequests + totals.FlashRequests + } + } + switch modelClass { + case geminiModelFlash: + return totals.FlashRequests + default: + return totals.ProRequests + } +} + func (s *RateLimitService) getGeminiUsageTotals(accountID int64, windowStart, now time.Time) (GeminiUsageTotals, bool) { s.usageCacheMu.RLock() defer s.usageCacheMu.RUnlock() @@ -381,10 +625,31 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head } } - // 2. 尝试从响应头解析重置时间(Anthropic) + // 2. Anthropic 平台:尝试解析 per-window 头(5h / 7d),选择实际触发的窗口 + if result := calculateAnthropic429ResetTime(headers); result != nil { + if err := s.accountRepo.SetRateLimited(ctx, account.ID, result.resetAt); err != nil { + slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) + return + } + + // 更新 session window:优先使用 5h-reset 头精确计算,否则从 resetAt 反推 + windowEnd := result.resetAt + if result.fiveHourReset != nil { + windowEnd = *result.fiveHourReset + } + windowStart := windowEnd.Add(-5 * time.Hour) + if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil { + slog.Warn("rate_limit_update_session_window_failed", "account_id", account.ID, "error", err) + } + + slog.Info("anthropic_account_rate_limited", "account_id", account.ID, "reset_at", result.resetAt, "reset_in", time.Until(result.resetAt).Truncate(time.Second)) + return + } + + // 3. 尝试从响应头解析重置时间(Anthropic 聚合头,向后兼容) resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset") - // 3. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini) + // 4. 
如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini) if resetTimestamp == "" { switch account.Platform { case PlatformOpenAI: @@ -411,7 +676,17 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head } } - // 没有重置时间,使用默认5分钟 + // Anthropic 平台:没有限流重置时间的 429 可能是非真实限流(如 Extra usage required), + // 不标记账号限流状态,直接透传错误给客户端 + if account.Platform == PlatformAnthropic { + slog.Warn("rate_limit_429_no_reset_time_skipped", + "account_id", account.ID, + "platform", account.Platform, + "reason", "no rate limit reset time in headers, likely not a real rate limit") + return + } + + // 其他平台:没有重置时间,使用默认5分钟 resetAt := time.Now().Add(5 * time.Minute) slog.Warn("rate_limit_no_reset_time", "account_id", account.ID, "platform", account.Platform, "using_default", "5m") if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { @@ -497,6 +772,112 @@ func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *tim return nil } +// anthropic429Result holds the parsed Anthropic 429 rate-limit information. +type anthropic429Result struct { + resetAt time.Time // The correct reset time to use for SetRateLimited + fiveHourReset *time.Time // 5h window reset timestamp (for session window calculation), nil if not available +} + +// calculateAnthropic429ResetTime parses Anthropic's per-window rate-limit headers +// to determine which window (5h or 7d) actually triggered the 429. +// +// Headers used: +// - anthropic-ratelimit-unified-5h-utilization / anthropic-ratelimit-unified-5h-surpassed-threshold +// - anthropic-ratelimit-unified-5h-reset +// - anthropic-ratelimit-unified-7d-utilization / anthropic-ratelimit-unified-7d-surpassed-threshold +// - anthropic-ratelimit-unified-7d-reset +// +// Returns nil when the per-window headers are absent (caller should fall back to +// the aggregated anthropic-ratelimit-unified-reset header). 
+func calculateAnthropic429ResetTime(headers http.Header) *anthropic429Result { + reset5hStr := headers.Get("anthropic-ratelimit-unified-5h-reset") + reset7dStr := headers.Get("anthropic-ratelimit-unified-7d-reset") + + if reset5hStr == "" && reset7dStr == "" { + return nil + } + + var reset5h, reset7d *time.Time + if ts, err := strconv.ParseInt(reset5hStr, 10, 64); err == nil { + t := time.Unix(ts, 0) + reset5h = &t + } + if ts, err := strconv.ParseInt(reset7dStr, 10, 64); err == nil { + t := time.Unix(ts, 0) + reset7d = &t + } + + is5hExceeded := isAnthropicWindowExceeded(headers, "5h") + is7dExceeded := isAnthropicWindowExceeded(headers, "7d") + + slog.Info("anthropic_429_window_analysis", + "is_5h_exceeded", is5hExceeded, + "is_7d_exceeded", is7dExceeded, + "reset_5h", reset5hStr, + "reset_7d", reset7dStr, + ) + + // Select the correct reset time based on which window(s) are exceeded. + var chosen *time.Time + switch { + case is5hExceeded && is7dExceeded: + // Both exceeded → prefer 7d (longer cooldown), fall back to 5h + chosen = reset7d + if chosen == nil { + chosen = reset5h + } + case is5hExceeded: + chosen = reset5h + case is7dExceeded: + chosen = reset7d + default: + // Neither flag clearly exceeded — pick the sooner reset as best guess + chosen = pickSooner(reset5h, reset7d) + } + + if chosen == nil { + return nil + } + return &anthropic429Result{resetAt: *chosen, fiveHourReset: reset5h} +} + +// isAnthropicWindowExceeded checks whether a given Anthropic rate-limit window +// (e.g. "5h" or "7d") has been exceeded, using utilization and surpassed-threshold headers. 
+func isAnthropicWindowExceeded(headers http.Header, window string) bool { + prefix := "anthropic-ratelimit-unified-" + window + "-" + + // Check surpassed-threshold first (most explicit signal) + if st := headers.Get(prefix + "surpassed-threshold"); strings.EqualFold(st, "true") { + return true + } + + // Fall back to utilization >= 1.0 + if utilStr := headers.Get(prefix + "utilization"); utilStr != "" { + if util, err := strconv.ParseFloat(utilStr, 64); err == nil && util >= 1.0-1e-9 { + // Use a small epsilon to handle floating point: treat 0.9999999... as >= 1.0 + return true + } + } + + return false +} + +// pickSooner returns whichever of the two time pointers is earlier. +// If only one is non-nil, it is returned. If both are nil, returns nil. +func pickSooner(a, b *time.Time) *time.Time { + switch { + case a != nil && b != nil: + if a.Before(*b) { + return a + } + return b + case a != nil: + return a + default: + return b + } +} + // parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳 // OpenAI 的 usage_limit_reached 错误格式: // @@ -611,7 +992,19 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) if err := s.accountRepo.ClearAntigravityQuotaScopes(ctx, accountID); err != nil { return err } - return s.accountRepo.ClearModelRateLimits(ctx, accountID) + if err := s.accountRepo.ClearModelRateLimits(ctx, accountID); err != nil { + return err + } + // 清除限流时一并清理临时不可调度状态,避免周限/窗口重置后仍被本地临时状态阻断。 + if err := s.accountRepo.ClearTempUnschedulable(ctx, accountID); err != nil { + return err + } + if s.tempUnschedCache != nil { + if err := s.tempUnschedCache.DeleteTempUnsched(ctx, accountID); err != nil { + slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err) + } + } + return nil } func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error { diff --git a/backend/internal/service/ratelimit_service_401_test.go b/backend/internal/service/ratelimit_service_401_test.go 
index 36357a4b..7bced46f 100644 --- a/backend/internal/service/ratelimit_service_401_test.go +++ b/backend/internal/service/ratelimit_service_401_test.go @@ -41,7 +41,7 @@ func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, acc return r.err } -func TestRateLimitService_HandleUpstreamError_OAuth401MarksError(t *testing.T) { +func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) { tests := []struct { name string platform string @@ -76,9 +76,8 @@ func TestRateLimitService_HandleUpstreamError_OAuth401MarksError(t *testing.T) { shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) require.True(t, shouldDisable) - require.Equal(t, 1, repo.setErrorCalls) - require.Equal(t, 0, repo.tempCalls) - require.Contains(t, repo.lastErrorMsg, "Authentication failed (401)") + require.Equal(t, 0, repo.setErrorCalls) + require.Equal(t, 1, repo.tempCalls) require.Len(t, invalidator.accounts, 1) }) } @@ -98,7 +97,8 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) require.True(t, shouldDisable) - require.Equal(t, 1, repo.setErrorCalls) + require.Equal(t, 0, repo.setErrorCalls) + require.Equal(t, 1, repo.tempCalls) require.Len(t, invalidator.accounts, 1) } diff --git a/backend/internal/service/ratelimit_service_anthropic_test.go b/backend/internal/service/ratelimit_service_anthropic_test.go new file mode 100644 index 00000000..eaeaf30e --- /dev/null +++ b/backend/internal/service/ratelimit_service_anthropic_test.go @@ -0,0 +1,202 @@ +package service + +import ( + "net/http" + "testing" + "time" +) + +func TestCalculateAnthropic429ResetTime_Only5hExceeded(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.02") + 
headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") + headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.32") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1770998400) + + if result.fiveHourReset == nil || !result.fiveHourReset.Equal(time.Unix(1770998400, 0)) { + t.Errorf("expected fiveHourReset=1770998400, got %v", result.fiveHourReset) + } +} + +func TestCalculateAnthropic429ResetTime_Only7dExceeded(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-utilization", "0.50") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") + headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.05") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1771549200) + + // fiveHourReset should still be populated for session window calculation + if result.fiveHourReset == nil || !result.fiveHourReset.Equal(time.Unix(1770998400, 0)) { + t.Errorf("expected fiveHourReset=1770998400, got %v", result.fiveHourReset) + } +} + +func TestCalculateAnthropic429ResetTime_BothExceeded(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.10") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") + headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.02") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1771549200) +} + +func TestCalculateAnthropic429ResetTime_NoPerWindowHeaders(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + if result != nil { + t.Errorf("expected nil result when no per-window headers, got 
resetAt=%v", result.resetAt) + } +} + +func TestCalculateAnthropic429ResetTime_NoHeaders(t *testing.T) { + result := calculateAnthropic429ResetTime(http.Header{}) + if result != nil { + t.Errorf("expected nil result for empty headers, got resetAt=%v", result.resetAt) + } +} + +func TestCalculateAnthropic429ResetTime_SurpassedThreshold(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-surpassed-threshold", "true") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") + headers.Set("anthropic-ratelimit-unified-7d-surpassed-threshold", "false") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1770998400) +} + +func TestCalculateAnthropic429ResetTime_UtilizationExactlyOne(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.0") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") + headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.5") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1770998400) +} + +func TestCalculateAnthropic429ResetTime_NeitherExceeded_UsesShorter(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-utilization", "0.95") + headers.Set("anthropic-ratelimit-unified-5h-reset", "1770998400") // sooner + headers.Set("anthropic-ratelimit-unified-7d-utilization", "0.80") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") // later + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1770998400) +} + +func TestCalculateAnthropic429ResetTime_Only5hResetHeader(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-5h-utilization", "1.05") + headers.Set("anthropic-ratelimit-unified-5h-reset", 
"1770998400") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1770998400) +} + +func TestCalculateAnthropic429ResetTime_Only7dResetHeader(t *testing.T) { + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-7d-utilization", "1.03") + headers.Set("anthropic-ratelimit-unified-7d-reset", "1771549200") + + result := calculateAnthropic429ResetTime(headers) + assertAnthropicResult(t, result, 1771549200) + + if result.fiveHourReset != nil { + t.Errorf("expected fiveHourReset=nil when no 5h headers, got %v", result.fiveHourReset) + } +} + +func TestIsAnthropicWindowExceeded(t *testing.T) { + tests := []struct { + name string + headers http.Header + window string + expected bool + }{ + { + name: "utilization above 1.0", + headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "1.02"), + window: "5h", + expected: true, + }, + { + name: "utilization exactly 1.0", + headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "1.0"), + window: "5h", + expected: true, + }, + { + name: "utilization below 1.0", + headers: makeHeader("anthropic-ratelimit-unified-5h-utilization", "0.99"), + window: "5h", + expected: false, + }, + { + name: "surpassed-threshold true", + headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "true"), + window: "7d", + expected: true, + }, + { + name: "surpassed-threshold True (case insensitive)", + headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "True"), + window: "7d", + expected: true, + }, + { + name: "surpassed-threshold false", + headers: makeHeader("anthropic-ratelimit-unified-7d-surpassed-threshold", "false"), + window: "7d", + expected: false, + }, + { + name: "no headers", + headers: http.Header{}, + window: "5h", + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := isAnthropicWindowExceeded(tc.headers, tc.window) + if got != tc.expected { + t.Errorf("expected %v, got %v", 
tc.expected, got) + } + }) + } +} + +// assertAnthropicResult is a test helper that verifies the result is non-nil and +// has the expected resetAt unix timestamp. +func assertAnthropicResult(t *testing.T, result *anthropic429Result, wantUnix int64) { + t.Helper() + if result == nil { + t.Fatal("expected non-nil result") + return // unreachable, but satisfies staticcheck SA5011 + } + want := time.Unix(wantUnix, 0) + if !result.resetAt.Equal(want) { + t.Errorf("expected resetAt=%v, got %v", want, result.resetAt) + } +} + +func makeHeader(key, value string) http.Header { + h := http.Header{} + h.Set(key, value) + return h +} diff --git a/backend/internal/service/ratelimit_service_clear_test.go b/backend/internal/service/ratelimit_service_clear_test.go new file mode 100644 index 00000000..f48151ed --- /dev/null +++ b/backend/internal/service/ratelimit_service_clear_test.go @@ -0,0 +1,172 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type rateLimitClearRepoStub struct { + mockAccountRepoForGemini + clearRateLimitCalls int + clearAntigravityCalls int + clearModelRateLimitCalls int + clearTempUnschedCalls int + clearRateLimitErr error + clearAntigravityErr error + clearModelRateLimitErr error + clearTempUnschedulableErr error +} + +func (r *rateLimitClearRepoStub) ClearRateLimit(ctx context.Context, id int64) error { + r.clearRateLimitCalls++ + return r.clearRateLimitErr +} + +func (r *rateLimitClearRepoStub) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error { + r.clearAntigravityCalls++ + return r.clearAntigravityErr +} + +func (r *rateLimitClearRepoStub) ClearModelRateLimits(ctx context.Context, id int64) error { + r.clearModelRateLimitCalls++ + return r.clearModelRateLimitErr +} + +func (r *rateLimitClearRepoStub) ClearTempUnschedulable(ctx context.Context, id int64) error { + r.clearTempUnschedCalls++ + return 
r.clearTempUnschedulableErr +} + +type tempUnschedCacheRecorder struct { + deletedIDs []int64 + deleteErr error +} + +func (c *tempUnschedCacheRecorder) SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error { + return nil +} + +func (c *tempUnschedCacheRecorder) GetTempUnsched(ctx context.Context, accountID int64) (*TempUnschedState, error) { + return nil, nil +} + +func (c *tempUnschedCacheRecorder) DeleteTempUnsched(ctx context.Context, accountID int64) error { + c.deletedIDs = append(c.deletedIDs, accountID) + return c.deleteErr +} + +func TestRateLimitService_ClearRateLimit_AlsoClearsTempUnschedulable(t *testing.T) { + repo := &rateLimitClearRepoStub{} + cache := &tempUnschedCacheRecorder{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache) + + err := svc.ClearRateLimit(context.Background(), 42) + require.NoError(t, err) + + require.Equal(t, 1, repo.clearRateLimitCalls) + require.Equal(t, 1, repo.clearAntigravityCalls) + require.Equal(t, 1, repo.clearModelRateLimitCalls) + require.Equal(t, 1, repo.clearTempUnschedCalls) + require.Equal(t, []int64{42}, cache.deletedIDs) +} + +func TestRateLimitService_ClearRateLimit_ClearTempUnschedulableFailed(t *testing.T) { + repo := &rateLimitClearRepoStub{ + clearTempUnschedulableErr: errors.New("clear temp unsched failed"), + } + cache := &tempUnschedCacheRecorder{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache) + + err := svc.ClearRateLimit(context.Background(), 7) + require.Error(t, err) + + require.Equal(t, 1, repo.clearTempUnschedCalls) + require.Empty(t, cache.deletedIDs) +} + +func TestRateLimitService_ClearRateLimit_ClearRateLimitFailed(t *testing.T) { + repo := &rateLimitClearRepoStub{ + clearRateLimitErr: errors.New("clear rate limit failed"), + } + cache := &tempUnschedCacheRecorder{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache) + + err := svc.ClearRateLimit(context.Background(), 11) + require.Error(t, err) + + 
require.Equal(t, 1, repo.clearRateLimitCalls) + require.Equal(t, 0, repo.clearAntigravityCalls) + require.Equal(t, 0, repo.clearModelRateLimitCalls) + require.Equal(t, 0, repo.clearTempUnschedCalls) + require.Empty(t, cache.deletedIDs) +} + +func TestRateLimitService_ClearRateLimit_ClearAntigravityFailed(t *testing.T) { + repo := &rateLimitClearRepoStub{ + clearAntigravityErr: errors.New("clear antigravity failed"), + } + cache := &tempUnschedCacheRecorder{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache) + + err := svc.ClearRateLimit(context.Background(), 12) + require.Error(t, err) + + require.Equal(t, 1, repo.clearRateLimitCalls) + require.Equal(t, 1, repo.clearAntigravityCalls) + require.Equal(t, 0, repo.clearModelRateLimitCalls) + require.Equal(t, 0, repo.clearTempUnschedCalls) + require.Empty(t, cache.deletedIDs) +} + +func TestRateLimitService_ClearRateLimit_ClearModelRateLimitsFailed(t *testing.T) { + repo := &rateLimitClearRepoStub{ + clearModelRateLimitErr: errors.New("clear model rate limits failed"), + } + cache := &tempUnschedCacheRecorder{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache) + + err := svc.ClearRateLimit(context.Background(), 13) + require.Error(t, err) + + require.Equal(t, 1, repo.clearRateLimitCalls) + require.Equal(t, 1, repo.clearAntigravityCalls) + require.Equal(t, 1, repo.clearModelRateLimitCalls) + require.Equal(t, 0, repo.clearTempUnschedCalls) + require.Empty(t, cache.deletedIDs) +} + +func TestRateLimitService_ClearRateLimit_CacheDeleteFailedShouldNotFail(t *testing.T) { + repo := &rateLimitClearRepoStub{} + cache := &tempUnschedCacheRecorder{ + deleteErr: errors.New("cache delete failed"), + } + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache) + + err := svc.ClearRateLimit(context.Background(), 14) + require.NoError(t, err) + + require.Equal(t, 1, repo.clearRateLimitCalls) + require.Equal(t, 1, repo.clearAntigravityCalls) + require.Equal(t, 1, 
repo.clearModelRateLimitCalls) + require.Equal(t, 1, repo.clearTempUnschedCalls) + require.Equal(t, []int64{14}, cache.deletedIDs) +} + +func TestRateLimitService_ClearRateLimit_WithoutTempUnschedCache(t *testing.T) { + repo := &rateLimitClearRepoStub{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + err := svc.ClearRateLimit(context.Background(), 15) + require.NoError(t, err) + + require.Equal(t, 1, repo.clearRateLimitCalls) + require.Equal(t, 1, repo.clearAntigravityCalls) + require.Equal(t, 1, repo.clearModelRateLimitCalls) + require.Equal(t, 1, repo.clearTempUnschedCalls) +} diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index ad277ca0..b22da752 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -174,6 +174,33 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ return codes, nil } +// CreateCode creates a redeem code with caller-provided code value. +// It is primarily used by admin integrations that require an external order ID +// to be mapped to a deterministic redeem code. 
+func (s *RedeemService) CreateCode(ctx context.Context, code *RedeemCode) error { + if code == nil { + return errors.New("redeem code is required") + } + code.Code = strings.TrimSpace(code.Code) + if code.Code == "" { + return errors.New("code is required") + } + if code.Type == "" { + code.Type = RedeemTypeBalance + } + if code.Type != RedeemTypeInvitation && code.Value <= 0 { + return errors.New("value must be greater than 0") + } + if code.Status == "" { + code.Status = StatusUnused + } + + if err := s.redeemRepo.Create(ctx, code); err != nil { + return fmt.Errorf("create redeem code: %w", err) + } + return nil +} + // checkRedeemRateLimit 检查用户兑换错误次数是否超限 func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error { if s.cache == nil { diff --git a/backend/internal/service/registration_email_policy.go b/backend/internal/service/registration_email_policy.go new file mode 100644 index 00000000..875668c7 --- /dev/null +++ b/backend/internal/service/registration_email_policy.go @@ -0,0 +1,123 @@ +package service + +import ( + "encoding/json" + "fmt" + "regexp" + "strings" +) + +var registrationEmailDomainPattern = regexp.MustCompile( + `^[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?(?:\.[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?)+$`, +) + +// RegistrationEmailSuffix extracts normalized suffix in "@domain" form. +func RegistrationEmailSuffix(email string) string { + _, domain, ok := splitEmailForPolicy(email) + if !ok { + return "" + } + return "@" + domain +} + +// IsRegistrationEmailSuffixAllowed checks whether an email is allowed by suffix whitelist. +// Empty whitelist means allow all. 
+func IsRegistrationEmailSuffixAllowed(email string, whitelist []string) bool { + if len(whitelist) == 0 { + return true + } + suffix := RegistrationEmailSuffix(email) + if suffix == "" { + return false + } + for _, allowed := range whitelist { + if suffix == allowed { + return true + } + } + return false +} + +// NormalizeRegistrationEmailSuffixWhitelist normalizes and validates suffix whitelist items. +func NormalizeRegistrationEmailSuffixWhitelist(raw []string) ([]string, error) { + return normalizeRegistrationEmailSuffixWhitelist(raw, true) +} + +// ParseRegistrationEmailSuffixWhitelist parses persisted JSON into normalized suffixes. +// Invalid entries are ignored to keep old misconfigurations from breaking runtime reads. +func ParseRegistrationEmailSuffixWhitelist(raw string) []string { + raw = strings.TrimSpace(raw) + if raw == "" { + return []string{} + } + var items []string + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return []string{} + } + normalized, _ := normalizeRegistrationEmailSuffixWhitelist(items, false) + if len(normalized) == 0 { + return []string{} + } + return normalized +} + +func normalizeRegistrationEmailSuffixWhitelist(raw []string, strict bool) ([]string, error) { + if len(raw) == 0 { + return nil, nil + } + + seen := make(map[string]struct{}, len(raw)) + out := make([]string, 0, len(raw)) + for _, item := range raw { + normalized, err := normalizeRegistrationEmailSuffix(item) + if err != nil { + if strict { + return nil, err + } + continue + } + if normalized == "" { + continue + } + if _, ok := seen[normalized]; ok { + continue + } + seen[normalized] = struct{}{} + out = append(out, normalized) + } + + if len(out) == 0 { + return nil, nil + } + return out, nil +} + +func normalizeRegistrationEmailSuffix(raw string) (string, error) { + value := strings.ToLower(strings.TrimSpace(raw)) + if value == "" { + return "", nil + } + + domain := value + if strings.Contains(value, "@") { + if !strings.HasPrefix(value, "@") || 
strings.Count(value, "@") != 1 { + return "", fmt.Errorf("invalid email suffix: %q", raw) + } + domain = strings.TrimPrefix(value, "@") + } + + if domain == "" || strings.Contains(domain, "@") || !registrationEmailDomainPattern.MatchString(domain) { + return "", fmt.Errorf("invalid email suffix: %q", raw) + } + + return "@" + domain, nil +} + +func splitEmailForPolicy(raw string) (local string, domain string, ok bool) { + email := strings.ToLower(strings.TrimSpace(raw)) + local, domain, found := strings.Cut(email, "@") + if !found || local == "" || domain == "" || strings.Contains(domain, "@") { + return "", "", false + } + return local, domain, true +} diff --git a/backend/internal/service/registration_email_policy_test.go b/backend/internal/service/registration_email_policy_test.go new file mode 100644 index 00000000..f0c46642 --- /dev/null +++ b/backend/internal/service/registration_email_policy_test.go @@ -0,0 +1,31 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNormalizeRegistrationEmailSuffixWhitelist(t *testing.T) { + got, err := NormalizeRegistrationEmailSuffixWhitelist([]string{"example.com", "@EXAMPLE.COM", " @foo.bar "}) + require.NoError(t, err) + require.Equal(t, []string{"@example.com", "@foo.bar"}, got) +} + +func TestNormalizeRegistrationEmailSuffixWhitelist_Invalid(t *testing.T) { + _, err := NormalizeRegistrationEmailSuffixWhitelist([]string{"@invalid_domain"}) + require.Error(t, err) +} + +func TestParseRegistrationEmailSuffixWhitelist(t *testing.T) { + got := ParseRegistrationEmailSuffixWhitelist(`["example.com","@foo.bar","@invalid_domain"]`) + require.Equal(t, []string{"@example.com", "@foo.bar"}, got) +} + +func TestIsRegistrationEmailSuffixAllowed(t *testing.T) { + require.True(t, IsRegistrationEmailSuffixAllowed("user@example.com", []string{"@example.com"})) + require.False(t, IsRegistrationEmailSuffixAllowed("user@sub.example.com", []string{"@example.com"})) + 
require.True(t, IsRegistrationEmailSuffixAllowed("user@any.com", []string{})) +} diff --git a/backend/internal/service/request_metadata.go b/backend/internal/service/request_metadata.go new file mode 100644 index 00000000..5c81bbf1 --- /dev/null +++ b/backend/internal/service/request_metadata.go @@ -0,0 +1,216 @@ +package service + +import ( + "context" + "sync/atomic" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" +) + +type requestMetadataContextKey struct{} + +var requestMetadataKey = requestMetadataContextKey{} + +type RequestMetadata struct { + IsMaxTokensOneHaikuRequest *bool + ThinkingEnabled *bool + PrefetchedStickyAccountID *int64 + PrefetchedStickyGroupID *int64 + SingleAccountRetry *bool + AccountSwitchCount *int +} + +var ( + requestMetadataFallbackIsMaxTokensOneHaikuTotal atomic.Int64 + requestMetadataFallbackThinkingEnabledTotal atomic.Int64 + requestMetadataFallbackPrefetchedStickyAccount atomic.Int64 + requestMetadataFallbackPrefetchedStickyGroup atomic.Int64 + requestMetadataFallbackSingleAccountRetryTotal atomic.Int64 + requestMetadataFallbackAccountSwitchCountTotal atomic.Int64 +) + +func RequestMetadataFallbackStats() (isMaxTokensOneHaiku, thinkingEnabled, prefetchedStickyAccount, prefetchedStickyGroup, singleAccountRetry, accountSwitchCount int64) { + return requestMetadataFallbackIsMaxTokensOneHaikuTotal.Load(), + requestMetadataFallbackThinkingEnabledTotal.Load(), + requestMetadataFallbackPrefetchedStickyAccount.Load(), + requestMetadataFallbackPrefetchedStickyGroup.Load(), + requestMetadataFallbackSingleAccountRetryTotal.Load(), + requestMetadataFallbackAccountSwitchCountTotal.Load() +} + +func metadataFromContext(ctx context.Context) *RequestMetadata { + if ctx == nil { + return nil + } + md, _ := ctx.Value(requestMetadataKey).(*RequestMetadata) + return md +} + +func updateRequestMetadata( + ctx context.Context, + bridgeOldKeys bool, + update func(md *RequestMetadata), + legacyBridge func(ctx context.Context) context.Context, +) 
context.Context { + if ctx == nil { + return nil + } + current := metadataFromContext(ctx) + next := &RequestMetadata{} + if current != nil { + *next = *current + } + update(next) + ctx = context.WithValue(ctx, requestMetadataKey, next) + if bridgeOldKeys && legacyBridge != nil { + ctx = legacyBridge(ctx) + } + return ctx +} + +func WithIsMaxTokensOneHaikuRequest(ctx context.Context, value bool, bridgeOldKeys bool) context.Context { + return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) { + v := value + md.IsMaxTokensOneHaikuRequest = &v + }, func(base context.Context) context.Context { + return context.WithValue(base, ctxkey.IsMaxTokensOneHaikuRequest, value) + }) +} + +func WithThinkingEnabled(ctx context.Context, value bool, bridgeOldKeys bool) context.Context { + return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) { + v := value + md.ThinkingEnabled = &v + }, func(base context.Context) context.Context { + return context.WithValue(base, ctxkey.ThinkingEnabled, value) + }) +} + +func WithPrefetchedStickySession(ctx context.Context, accountID, groupID int64, bridgeOldKeys bool) context.Context { + return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) { + account := accountID + group := groupID + md.PrefetchedStickyAccountID = &account + md.PrefetchedStickyGroupID = &group + }, func(base context.Context) context.Context { + bridged := context.WithValue(base, ctxkey.PrefetchedStickyAccountID, accountID) + return context.WithValue(bridged, ctxkey.PrefetchedStickyGroupID, groupID) + }) +} + +func WithSingleAccountRetry(ctx context.Context, value bool, bridgeOldKeys bool) context.Context { + return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) { + v := value + md.SingleAccountRetry = &v + }, func(base context.Context) context.Context { + return context.WithValue(base, ctxkey.SingleAccountRetry, value) + }) +} + +func WithAccountSwitchCount(ctx context.Context, value int, 
bridgeOldKeys bool) context.Context { + return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) { + v := value + md.AccountSwitchCount = &v + }, func(base context.Context) context.Context { + return context.WithValue(base, ctxkey.AccountSwitchCount, value) + }) +} + +func IsMaxTokensOneHaikuRequestFromContext(ctx context.Context) (bool, bool) { + if md := metadataFromContext(ctx); md != nil && md.IsMaxTokensOneHaikuRequest != nil { + return *md.IsMaxTokensOneHaikuRequest, true + } + if ctx == nil { + return false, false + } + if value, ok := ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok { + requestMetadataFallbackIsMaxTokensOneHaikuTotal.Add(1) + return value, true + } + return false, false +} + +func ThinkingEnabledFromContext(ctx context.Context) (bool, bool) { + if md := metadataFromContext(ctx); md != nil && md.ThinkingEnabled != nil { + return *md.ThinkingEnabled, true + } + if ctx == nil { + return false, false + } + if value, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok { + requestMetadataFallbackThinkingEnabledTotal.Add(1) + return value, true + } + return false, false +} + +func PrefetchedStickyGroupIDFromContext(ctx context.Context) (int64, bool) { + if md := metadataFromContext(ctx); md != nil && md.PrefetchedStickyGroupID != nil { + return *md.PrefetchedStickyGroupID, true + } + if ctx == nil { + return 0, false + } + v := ctx.Value(ctxkey.PrefetchedStickyGroupID) + switch t := v.(type) { + case int64: + requestMetadataFallbackPrefetchedStickyGroup.Add(1) + return t, true + case int: + requestMetadataFallbackPrefetchedStickyGroup.Add(1) + return int64(t), true + } + return 0, false +} + +func PrefetchedStickyAccountIDFromContext(ctx context.Context) (int64, bool) { + if md := metadataFromContext(ctx); md != nil && md.PrefetchedStickyAccountID != nil { + return *md.PrefetchedStickyAccountID, true + } + if ctx == nil { + return 0, false + } + v := ctx.Value(ctxkey.PrefetchedStickyAccountID) + switch t := v.(type) { + 
case int64: + requestMetadataFallbackPrefetchedStickyAccount.Add(1) + return t, true + case int: + requestMetadataFallbackPrefetchedStickyAccount.Add(1) + return int64(t), true + } + return 0, false +} + +func SingleAccountRetryFromContext(ctx context.Context) (bool, bool) { + if md := metadataFromContext(ctx); md != nil && md.SingleAccountRetry != nil { + return *md.SingleAccountRetry, true + } + if ctx == nil { + return false, false + } + if value, ok := ctx.Value(ctxkey.SingleAccountRetry).(bool); ok { + requestMetadataFallbackSingleAccountRetryTotal.Add(1) + return value, true + } + return false, false +} + +func AccountSwitchCountFromContext(ctx context.Context) (int, bool) { + if md := metadataFromContext(ctx); md != nil && md.AccountSwitchCount != nil { + return *md.AccountSwitchCount, true + } + if ctx == nil { + return 0, false + } + v := ctx.Value(ctxkey.AccountSwitchCount) + switch t := v.(type) { + case int: + requestMetadataFallbackAccountSwitchCountTotal.Add(1) + return t, true + case int64: + requestMetadataFallbackAccountSwitchCountTotal.Add(1) + return int(t), true + } + return 0, false +} diff --git a/backend/internal/service/request_metadata_test.go b/backend/internal/service/request_metadata_test.go new file mode 100644 index 00000000..7d192699 --- /dev/null +++ b/backend/internal/service/request_metadata_test.go @@ -0,0 +1,119 @@ +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +func TestRequestMetadataWriteAndRead_NoBridge(t *testing.T) { + ctx := context.Background() + ctx = WithIsMaxTokensOneHaikuRequest(ctx, true, false) + ctx = WithThinkingEnabled(ctx, true, false) + ctx = WithPrefetchedStickySession(ctx, 123, 456, false) + ctx = WithSingleAccountRetry(ctx, true, false) + ctx = WithAccountSwitchCount(ctx, 2, false) + + isHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(ctx) + require.True(t, ok) + require.True(t, isHaiku) + + thinking, 
ok := ThinkingEnabledFromContext(ctx) + require.True(t, ok) + require.True(t, thinking) + + accountID, ok := PrefetchedStickyAccountIDFromContext(ctx) + require.True(t, ok) + require.Equal(t, int64(123), accountID) + + groupID, ok := PrefetchedStickyGroupIDFromContext(ctx) + require.True(t, ok) + require.Equal(t, int64(456), groupID) + + singleRetry, ok := SingleAccountRetryFromContext(ctx) + require.True(t, ok) + require.True(t, singleRetry) + + switchCount, ok := AccountSwitchCountFromContext(ctx) + require.True(t, ok) + require.Equal(t, 2, switchCount) + + require.Nil(t, ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest)) + require.Nil(t, ctx.Value(ctxkey.ThinkingEnabled)) + require.Nil(t, ctx.Value(ctxkey.PrefetchedStickyAccountID)) + require.Nil(t, ctx.Value(ctxkey.PrefetchedStickyGroupID)) + require.Nil(t, ctx.Value(ctxkey.SingleAccountRetry)) + require.Nil(t, ctx.Value(ctxkey.AccountSwitchCount)) +} + +func TestRequestMetadataWrite_BridgeLegacyKeys(t *testing.T) { + ctx := context.Background() + ctx = WithIsMaxTokensOneHaikuRequest(ctx, true, true) + ctx = WithThinkingEnabled(ctx, true, true) + ctx = WithPrefetchedStickySession(ctx, 123, 456, true) + ctx = WithSingleAccountRetry(ctx, true, true) + ctx = WithAccountSwitchCount(ctx, 2, true) + + require.Equal(t, true, ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest)) + require.Equal(t, true, ctx.Value(ctxkey.ThinkingEnabled)) + require.Equal(t, int64(123), ctx.Value(ctxkey.PrefetchedStickyAccountID)) + require.Equal(t, int64(456), ctx.Value(ctxkey.PrefetchedStickyGroupID)) + require.Equal(t, true, ctx.Value(ctxkey.SingleAccountRetry)) + require.Equal(t, 2, ctx.Value(ctxkey.AccountSwitchCount)) +} + +func TestRequestMetadataRead_LegacyFallbackAndStats(t *testing.T) { + beforeHaiku, beforeThinking, beforeAccount, beforeGroup, beforeSingleRetry, beforeSwitchCount := RequestMetadataFallbackStats() + + ctx := context.Background() + ctx = context.WithValue(ctx, ctxkey.IsMaxTokensOneHaikuRequest, true) + ctx = 
context.WithValue(ctx, ctxkey.ThinkingEnabled, true) + ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyAccountID, int64(321)) + ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(654)) + ctx = context.WithValue(ctx, ctxkey.SingleAccountRetry, true) + ctx = context.WithValue(ctx, ctxkey.AccountSwitchCount, int64(3)) + + isHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(ctx) + require.True(t, ok) + require.True(t, isHaiku) + + thinking, ok := ThinkingEnabledFromContext(ctx) + require.True(t, ok) + require.True(t, thinking) + + accountID, ok := PrefetchedStickyAccountIDFromContext(ctx) + require.True(t, ok) + require.Equal(t, int64(321), accountID) + + groupID, ok := PrefetchedStickyGroupIDFromContext(ctx) + require.True(t, ok) + require.Equal(t, int64(654), groupID) + + singleRetry, ok := SingleAccountRetryFromContext(ctx) + require.True(t, ok) + require.True(t, singleRetry) + + switchCount, ok := AccountSwitchCountFromContext(ctx) + require.True(t, ok) + require.Equal(t, 3, switchCount) + + afterHaiku, afterThinking, afterAccount, afterGroup, afterSingleRetry, afterSwitchCount := RequestMetadataFallbackStats() + require.Equal(t, beforeHaiku+1, afterHaiku) + require.Equal(t, beforeThinking+1, afterThinking) + require.Equal(t, beforeAccount+1, afterAccount) + require.Equal(t, beforeGroup+1, afterGroup) + require.Equal(t, beforeSingleRetry+1, afterSingleRetry) + require.Equal(t, beforeSwitchCount+1, afterSwitchCount) +} + +func TestRequestMetadataRead_PreferMetadataOverLegacy(t *testing.T) { + ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, false) + ctx = WithThinkingEnabled(ctx, true, false) + + thinking, ok := ThinkingEnabledFromContext(ctx) + require.True(t, ok) + require.True(t, thinking) + require.Equal(t, false, ctx.Value(ctxkey.ThinkingEnabled)) +} diff --git a/backend/internal/service/response_header_filter.go b/backend/internal/service/response_header_filter.go new file mode 100644 index 00000000..81012b01 --- 
/dev/null +++ b/backend/internal/service/response_header_filter.go @@ -0,0 +1,13 @@ +package service + +import ( + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" +) + +func compileResponseHeaderFilter(cfg *config.Config) *responseheaders.CompiledHeaderFilter { + if cfg == nil { + return nil + } + return responseheaders.CompileHeaderFilter(cfg.Security.ResponseHeaders) +} diff --git a/backend/internal/service/rpm_cache.go b/backend/internal/service/rpm_cache.go new file mode 100644 index 00000000..07036219 --- /dev/null +++ b/backend/internal/service/rpm_cache.go @@ -0,0 +1,17 @@ +package service + +import "context" + +// RPMCache RPM 计数器缓存接口 +// 用于 Anthropic OAuth/SetupToken 账号的每分钟请求数限制 +type RPMCache interface { + // IncrementRPM 原子递增并返回当前分钟的计数 + // 使用 Redis 服务器时间确定 minute key,避免多实例时钟偏差 + IncrementRPM(ctx context.Context, accountID int64) (count int, err error) + + // GetRPM 获取当前分钟的 RPM 计数 + GetRPM(ctx context.Context, accountID int64) (count int, err error) + + // GetRPMBatch 批量获取多个账号的 RPM 计数(使用 Pipeline) + GetRPMBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) +} diff --git a/backend/internal/service/scheduler_shuffle_test.go b/backend/internal/service/scheduler_shuffle_test.go index 78ac5f57..0d82b2f3 100644 --- a/backend/internal/service/scheduler_shuffle_test.go +++ b/backend/internal/service/scheduler_shuffle_test.go @@ -125,13 +125,13 @@ func TestShuffleWithinSortGroups_MixedGroups(t *testing.T) { // ============ shuffleWithinPriorityAndLastUsed 测试 ============ func TestShuffleWithinPriorityAndLastUsed_Empty(t *testing.T) { - shuffleWithinPriorityAndLastUsed(nil) - shuffleWithinPriorityAndLastUsed([]*Account{}) + shuffleWithinPriorityAndLastUsed(nil, false) + shuffleWithinPriorityAndLastUsed([]*Account{}, false) } func TestShuffleWithinPriorityAndLastUsed_SingleElement(t *testing.T) { accounts := []*Account{{ID: 1, Priority: 1}} - shuffleWithinPriorityAndLastUsed(accounts) + 
shuffleWithinPriorityAndLastUsed(accounts, false) require.Equal(t, int64(1), accounts[0].ID) } @@ -146,7 +146,7 @@ func TestShuffleWithinPriorityAndLastUsed_SameGroup_Shuffled(t *testing.T) { for i := 0; i < 100; i++ { cpy := make([]*Account, len(accounts)) copy(cpy, accounts) - shuffleWithinPriorityAndLastUsed(cpy) + shuffleWithinPriorityAndLastUsed(cpy, false) seen[cpy[0].ID] = true } require.GreaterOrEqual(t, len(seen), 2, "same group should be shuffled") @@ -162,7 +162,7 @@ func TestShuffleWithinPriorityAndLastUsed_DifferentPriority_OrderPreserved(t *te for i := 0; i < 20; i++ { cpy := make([]*Account, len(accounts)) copy(cpy, accounts) - shuffleWithinPriorityAndLastUsed(cpy) + shuffleWithinPriorityAndLastUsed(cpy, false) require.Equal(t, int64(1), cpy[0].ID) require.Equal(t, int64(2), cpy[1].ID) require.Equal(t, int64(3), cpy[2].ID) @@ -182,7 +182,7 @@ func TestShuffleWithinPriorityAndLastUsed_DifferentLastUsedAt_OrderPreserved(t * for i := 0; i < 20; i++ { cpy := make([]*Account, len(accounts)) copy(cpy, accounts) - shuffleWithinPriorityAndLastUsed(cpy) + shuffleWithinPriorityAndLastUsed(cpy, false) require.Equal(t, int64(1), cpy[0].ID) require.Equal(t, int64(2), cpy[1].ID) require.Equal(t, int64(3), cpy[2].ID) diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go index 52d455b8..4c9540f1 100644 --- a/backend/internal/service/scheduler_snapshot_service.go +++ b/backend/internal/service/scheduler_snapshot_service.go @@ -4,12 +4,13 @@ import ( "context" "encoding/json" "errors" - "log" + "log/slog" "strconv" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) var ( @@ -103,7 +104,7 @@ func (s *SchedulerSnapshotService) ListSchedulableAccounts(ctx context.Context, if s.cache != nil { cached, hit, err := s.cache.GetSnapshot(ctx, bucket) if err != nil { - log.Printf("[Scheduler] cache read failed: bucket=%s err=%v", bucket.String(), 
err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] cache read failed: bucket=%s err=%v", bucket.String(), err) } else if hit { return derefAccounts(cached), useMixed, nil } @@ -123,7 +124,7 @@ func (s *SchedulerSnapshotService) ListSchedulableAccounts(ctx context.Context, if s.cache != nil { if err := s.cache.SetSnapshot(fallbackCtx, bucket, accounts); err != nil { - log.Printf("[Scheduler] cache write failed: bucket=%s err=%v", bucket.String(), err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] cache write failed: bucket=%s err=%v", bucket.String(), err) } } @@ -137,7 +138,7 @@ func (s *SchedulerSnapshotService) GetAccount(ctx context.Context, accountID int if s.cache != nil { account, err := s.cache.GetAccount(ctx, accountID) if err != nil { - log.Printf("[Scheduler] account cache read failed: id=%d err=%v", accountID, err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] account cache read failed: id=%d err=%v", accountID, err) } else if account != nil { return account, nil } @@ -167,17 +168,17 @@ func (s *SchedulerSnapshotService) runInitialRebuild() { defer cancel() buckets, err := s.cache.ListBuckets(ctx) if err != nil { - log.Printf("[Scheduler] list buckets failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] list buckets failed: %v", err) } if len(buckets) == 0 { buckets, err = s.defaultBuckets(ctx) if err != nil { - log.Printf("[Scheduler] default buckets failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] default buckets failed: %v", err) return } } if err := s.rebuildBuckets(ctx, buckets, "startup"); err != nil { - log.Printf("[Scheduler] rebuild startup failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] rebuild startup failed: %v", err) } } @@ -204,7 +205,7 @@ func (s *SchedulerSnapshotService) runFullRebuildWorker(interval time.Duration) select { case <-ticker.C: if err := s.triggerFullRebuild("interval"); 
err != nil { - log.Printf("[Scheduler] full rebuild failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] full rebuild failed: %v", err) } case <-s.stopCh: return @@ -221,13 +222,13 @@ func (s *SchedulerSnapshotService) pollOutbox() { watermark, err := s.cache.GetOutboxWatermark(ctx) if err != nil { - log.Printf("[Scheduler] outbox watermark read failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox watermark read failed: %v", err) return } events, err := s.outboxRepo.ListAfter(ctx, watermark, 200) if err != nil { - log.Printf("[Scheduler] outbox poll failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox poll failed: %v", err) return } if len(events) == 0 { @@ -240,14 +241,14 @@ func (s *SchedulerSnapshotService) pollOutbox() { err := s.handleOutboxEvent(eventCtx, event) cancel() if err != nil { - log.Printf("[Scheduler] outbox handle failed: id=%d type=%s err=%v", event.ID, event.EventType, err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox handle failed: id=%d type=%s err=%v", event.ID, event.EventType, err) return } } lastID := events[len(events)-1].ID if err := s.cache.SetOutboxWatermark(ctx, lastID); err != nil { - log.Printf("[Scheduler] outbox watermark write failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox watermark write failed: %v", err) } else { watermarkForCheck = lastID } @@ -304,13 +305,78 @@ func (s *SchedulerSnapshotService) handleBulkAccountEvent(ctx context.Context, p if payload == nil { return nil } - ids := parseInt64Slice(payload["account_ids"]) - for _, id := range ids { - if err := s.handleAccountEvent(ctx, &id, payload); err != nil { - return err + if s.accountRepo == nil { + return nil + } + + rawIDs := parseInt64Slice(payload["account_ids"]) + if len(rawIDs) == 0 { + return nil + } + + ids := make([]int64, 0, len(rawIDs)) + seen := make(map[int64]struct{}, len(rawIDs)) + for 
_, id := range rawIDs { + if id <= 0 { + continue + } + if _, exists := seen[id]; exists { + continue + } + seen[id] = struct{}{} + ids = append(ids, id) + } + if len(ids) == 0 { + return nil + } + + preloadGroupIDs := parseInt64Slice(payload["group_ids"]) + accounts, err := s.accountRepo.GetByIDs(ctx, ids) + if err != nil { + return err + } + + found := make(map[int64]struct{}, len(accounts)) + rebuildGroupSet := make(map[int64]struct{}, len(preloadGroupIDs)) + for _, gid := range preloadGroupIDs { + if gid > 0 { + rebuildGroupSet[gid] = struct{}{} } } - return nil + + for _, account := range accounts { + if account == nil || account.ID <= 0 { + continue + } + found[account.ID] = struct{}{} + if s.cache != nil { + if err := s.cache.SetAccount(ctx, account); err != nil { + return err + } + } + for _, gid := range account.GroupIDs { + if gid > 0 { + rebuildGroupSet[gid] = struct{}{} + } + } + } + + if s.cache != nil { + for _, id := range ids { + if _, ok := found[id]; ok { + continue + } + if err := s.cache.DeleteAccount(ctx, id); err != nil { + return err + } + } + } + + rebuildGroupIDs := make([]int64, 0, len(rebuildGroupSet)) + for gid := range rebuildGroupSet { + rebuildGroupIDs = append(rebuildGroupIDs, gid) + } + return s.rebuildByGroupIDs(ctx, rebuildGroupIDs, "account_bulk_change") } func (s *SchedulerSnapshotService) handleAccountEvent(ctx context.Context, accountID *int64, payload map[string]any) error { @@ -444,14 +510,14 @@ func (s *SchedulerSnapshotService) rebuildBucket(ctx context.Context, bucket Sch accounts, err := s.loadAccountsFromDB(rebuildCtx, bucket, bucket.Mode == SchedulerModeMixed) if err != nil { - log.Printf("[Scheduler] rebuild failed: bucket=%s reason=%s err=%v", bucket.String(), reason, err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] rebuild failed: bucket=%s reason=%s err=%v", bucket.String(), reason, err) return err } if err := s.cache.SetSnapshot(rebuildCtx, bucket, accounts); err != nil { - 
log.Printf("[Scheduler] rebuild cache failed: bucket=%s reason=%s err=%v", bucket.String(), reason, err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] rebuild cache failed: bucket=%s reason=%s err=%v", bucket.String(), reason, err) return err } - log.Printf("[Scheduler] rebuild ok: bucket=%s reason=%s size=%d", bucket.String(), reason, len(accounts)) + slog.Debug("[Scheduler] rebuild ok", "bucket", bucket.String(), "reason", reason, "size", len(accounts)) return nil } @@ -464,13 +530,13 @@ func (s *SchedulerSnapshotService) triggerFullRebuild(reason string) error { buckets, err := s.cache.ListBuckets(ctx) if err != nil { - log.Printf("[Scheduler] list buckets failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] list buckets failed: %v", err) return err } if len(buckets) == 0 { buckets, err = s.defaultBuckets(ctx) if err != nil { - log.Printf("[Scheduler] default buckets failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] default buckets failed: %v", err) return err } } @@ -484,7 +550,7 @@ func (s *SchedulerSnapshotService) checkOutboxLag(ctx context.Context, oldest Sc lag := time.Since(oldest.CreatedAt) if lagSeconds := int(lag.Seconds()); lagSeconds >= s.cfg.Gateway.Scheduling.OutboxLagWarnSeconds && s.cfg.Gateway.Scheduling.OutboxLagWarnSeconds > 0 { - log.Printf("[Scheduler] outbox lag warning: %ds", lagSeconds) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox lag warning: %ds", lagSeconds) } if s.cfg.Gateway.Scheduling.OutboxLagRebuildSeconds > 0 && int(lag.Seconds()) >= s.cfg.Gateway.Scheduling.OutboxLagRebuildSeconds { @@ -494,12 +560,12 @@ func (s *SchedulerSnapshotService) checkOutboxLag(ctx context.Context, oldest Sc s.lagMu.Unlock() if failures >= s.cfg.Gateway.Scheduling.OutboxLagRebuildFailures { - log.Printf("[Scheduler] outbox lag rebuild triggered: lag=%s failures=%d", lag, failures) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] 
outbox lag rebuild triggered: lag=%s failures=%d", lag, failures) s.lagMu.Lock() s.lagFailures = 0 s.lagMu.Unlock() if err := s.triggerFullRebuild("outbox_lag"); err != nil { - log.Printf("[Scheduler] outbox lag rebuild failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox lag rebuild failed: %v", err) } } } else { @@ -517,9 +583,9 @@ func (s *SchedulerSnapshotService) checkOutboxLag(ctx context.Context, oldest Sc return } if maxID-watermark >= int64(threshold) { - log.Printf("[Scheduler] outbox backlog rebuild triggered: backlog=%d", maxID-watermark) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox backlog rebuild triggered: backlog=%d", maxID-watermark) if err := s.triggerFullRebuild("outbox_backlog"); err != nil { - log.Printf("[Scheduler] outbox backlog rebuild failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox backlog rebuild failed: %v", err) } } } @@ -539,8 +605,10 @@ func (s *SchedulerSnapshotService) loadAccountsFromDB(ctx context.Context, bucke var err error if groupID > 0 { accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, groupID, platforms) - } else { + } else if s.isRunModeSimple() { accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) + } else { + accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, platforms) } if err != nil { return nil, err @@ -558,7 +626,10 @@ func (s *SchedulerSnapshotService) loadAccountsFromDB(ctx context.Context, bucke if groupID > 0 { return s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, groupID, bucket.Platform) } - return s.accountRepo.ListSchedulableByPlatform(ctx, bucket.Platform) + if s.isRunModeSimple() { + return s.accountRepo.ListSchedulableByPlatform(ctx, bucket.Platform) + } + return s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, bucket.Platform) } func (s *SchedulerSnapshotService) bucketFor(groupID *int64, platform string, mode string) 
SchedulerBucket { diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index f5ba9d71..5bfec32e 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -7,16 +7,31 @@ import ( "encoding/json" "errors" "fmt" + "log/slog" + "net/url" "strconv" "strings" + "sync/atomic" + "time" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "golang.org/x/sync/singleflight" ) var ( - ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") - ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found") + ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") + ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found") + ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found") + ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists") + ErrDefaultSubGroupInvalid = infraerrors.BadRequest( + "DEFAULT_SUBSCRIPTION_GROUP_INVALID", + "default subscription group must exist and be subscription type", + ) + ErrDefaultSubGroupDuplicate = infraerrors.BadRequest( + "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE", + "default subscription group cannot be duplicated", + ) ) type SettingRepository interface { @@ -29,12 +44,40 @@ type SettingRepository interface { Delete(ctx context.Context, key string) error } +// cachedMinVersion 缓存最低 Claude Code 版本号(进程内缓存,60s TTL) +type cachedMinVersion struct { + value string // 空字符串 = 不检查 + expiresAt int64 // unix nano +} + +// minVersionCache 最低版本号进程内缓存 +var minVersionCache atomic.Value // *cachedMinVersion + +// minVersionSF 防止缓存过期时 thundering herd +var minVersionSF singleflight.Group + +// minVersionCacheTTL 缓存有效期 +const minVersionCacheTTL = 60 * 
time.Second + +// minVersionErrorTTL DB 错误时的短缓存,快速重试 +const minVersionErrorTTL = 5 * time.Second + +// minVersionDBTimeout singleflight 内 DB 查询超时,独立于请求 context +const minVersionDBTimeout = 5 * time.Second + +// DefaultSubscriptionGroupReader validates group references used by default subscriptions. +type DefaultSubscriptionGroupReader interface { + GetByID(ctx context.Context, id int64) (*Group, error) +} + // SettingService 系统设置服务 type SettingService struct { - settingRepo SettingRepository - cfg *config.Config - onUpdate func() // Callback when settings are updated (for cache invalidation) - version string // Application version + settingRepo SettingRepository + defaultSubGroupReader DefaultSubscriptionGroupReader + cfg *config.Config + onUpdate func() // Callback when settings are updated (for cache invalidation) + onS3Update func() // Callback when Sora S3 settings are updated + version string // Application version } // NewSettingService 创建系统设置服务实例 @@ -45,6 +88,11 @@ func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *Setti } } +// SetDefaultSubscriptionGroupReader injects an optional group reader for default subscription validation. 
+func (s *SettingService) SetDefaultSubscriptionGroupReader(reader DefaultSubscriptionGroupReader) { + s.defaultSubGroupReader = reader +} + // GetAllSettings 获取所有系统设置 func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) { settings, err := s.settingRepo.GetAll(ctx) @@ -60,6 +108,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings keys := []string{ SettingKeyRegistrationEnabled, SettingKeyEmailVerifyEnabled, + SettingKeyRegistrationEmailSuffixWhitelist, SettingKeyPromoCodeEnabled, SettingKeyPasswordResetEnabled, SettingKeyInvitationCodeEnabled, @@ -76,6 +125,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyHideCcsImportButton, SettingKeyPurchaseSubscriptionEnabled, SettingKeyPurchaseSubscriptionURL, + SettingKeySoraClientEnabled, + SettingKeyCustomMenuItems, SettingKeyLinuxDoConnectEnabled, } @@ -94,27 +145,33 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings // Password reset requires email verification to be enabled emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true" passwordResetEnabled := emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true" + registrationEmailSuffixWhitelist := ParseRegistrationEmailSuffixWhitelist( + settings[SettingKeyRegistrationEmailSuffixWhitelist], + ) return &PublicSettings{ - RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", - EmailVerifyEnabled: emailVerifyEnabled, - PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 - PasswordResetEnabled: passwordResetEnabled, - InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true", - TotpEnabled: settings[SettingKeyTotpEnabled] == "true", - TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", - TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], - SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"), - SiteLogo: 
settings[SettingKeySiteLogo], - SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), - APIBaseURL: settings[SettingKeyAPIBaseURL], - ContactInfo: settings[SettingKeyContactInfo], - DocURL: settings[SettingKeyDocURL], - HomeContent: settings[SettingKeyHomeContent], - HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", - PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", - PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), - LinuxDoOAuthEnabled: linuxDoEnabled, + RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", + EmailVerifyEnabled: emailVerifyEnabled, + RegistrationEmailSuffixWhitelist: registrationEmailSuffixWhitelist, + PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 + PasswordResetEnabled: passwordResetEnabled, + InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true", + TotpEnabled: settings[SettingKeyTotpEnabled] == "true", + TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", + TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], + SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"), + SiteLogo: settings[SettingKeySiteLogo], + SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), + APIBaseURL: settings[SettingKeyAPIBaseURL], + ContactInfo: settings[SettingKeyContactInfo], + DocURL: settings[SettingKeyDocURL], + HomeContent: settings[SettingKeyHomeContent], + HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", + PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", + PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), + SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", + CustomMenuItems: settings[SettingKeyCustomMenuItems], + 
LinuxDoOAuthEnabled: linuxDoEnabled, }, nil } @@ -124,6 +181,11 @@ func (s *SettingService) SetOnUpdateCallback(callback func()) { s.onUpdate = callback } +// SetOnS3UpdateCallback 设置 Sora S3 配置变更时的回调函数(用于刷新 S3 客户端缓存)。 +func (s *SettingService) SetOnS3UpdateCallback(callback func()) { + s.onS3Update = callback +} + // SetVersion sets the application version for injection into public settings func (s *SettingService) SetVersion(version string) { s.version = version @@ -139,57 +201,187 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any // Return a struct that matches the frontend's expected format return &struct { - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - PasswordResetEnabled bool `json:"password_reset_enabled"` - InvitationCodeEnabled bool `json:"invitation_code_enabled"` - TotpEnabled bool `json:"totp_enabled"` - TurnstileEnabled bool `json:"turnstile_enabled"` - TurnstileSiteKey string `json:"turnstile_site_key,omitempty"` - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo,omitempty"` - SiteSubtitle string `json:"site_subtitle,omitempty"` - APIBaseURL string `json:"api_base_url,omitempty"` - ContactInfo string `json:"contact_info,omitempty"` - DocURL string `json:"doc_url,omitempty"` - HomeContent string `json:"home_content,omitempty"` - HideCcsImportButton bool `json:"hide_ccs_import_button"` - PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` - PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` - LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` - Version string `json:"version,omitempty"` + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` + PromoCodeEnabled bool 
`json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` + TotpEnabled bool `json:"totp_enabled"` + TurnstileEnabled bool `json:"turnstile_enabled"` + TurnstileSiteKey string `json:"turnstile_site_key,omitempty"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo,omitempty"` + SiteSubtitle string `json:"site_subtitle,omitempty"` + APIBaseURL string `json:"api_base_url,omitempty"` + ContactInfo string `json:"contact_info,omitempty"` + DocURL string `json:"doc_url,omitempty"` + HomeContent string `json:"home_content,omitempty"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` + PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` + SoraClientEnabled bool `json:"sora_client_enabled"` + CustomMenuItems json.RawMessage `json:"custom_menu_items"` + LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + Version string `json:"version,omitempty"` }{ - RegistrationEnabled: settings.RegistrationEnabled, - EmailVerifyEnabled: settings.EmailVerifyEnabled, - PromoCodeEnabled: settings.PromoCodeEnabled, - PasswordResetEnabled: settings.PasswordResetEnabled, - InvitationCodeEnabled: settings.InvitationCodeEnabled, - TotpEnabled: settings.TotpEnabled, - TurnstileEnabled: settings.TurnstileEnabled, - TurnstileSiteKey: settings.TurnstileSiteKey, - SiteName: settings.SiteName, - SiteLogo: settings.SiteLogo, - SiteSubtitle: settings.SiteSubtitle, - APIBaseURL: settings.APIBaseURL, - ContactInfo: settings.ContactInfo, - DocURL: settings.DocURL, - HomeContent: settings.HomeContent, - HideCcsImportButton: settings.HideCcsImportButton, - PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, - PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, - LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, - Version: s.version, + RegistrationEnabled: 
settings.RegistrationEnabled, + EmailVerifyEnabled: settings.EmailVerifyEnabled, + RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist, + PromoCodeEnabled: settings.PromoCodeEnabled, + PasswordResetEnabled: settings.PasswordResetEnabled, + InvitationCodeEnabled: settings.InvitationCodeEnabled, + TotpEnabled: settings.TotpEnabled, + TurnstileEnabled: settings.TurnstileEnabled, + TurnstileSiteKey: settings.TurnstileSiteKey, + SiteName: settings.SiteName, + SiteLogo: settings.SiteLogo, + SiteSubtitle: settings.SiteSubtitle, + APIBaseURL: settings.APIBaseURL, + ContactInfo: settings.ContactInfo, + DocURL: settings.DocURL, + HomeContent: settings.HomeContent, + HideCcsImportButton: settings.HideCcsImportButton, + PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, + PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, + SoraClientEnabled: settings.SoraClientEnabled, + CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), + LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, + Version: s.version, }, nil } +// filterUserVisibleMenuItems filters out admin-only menu items from a raw JSON +// array string, returning only items with visibility != "admin". 
+func filterUserVisibleMenuItems(raw string) json.RawMessage { + raw = strings.TrimSpace(raw) + if raw == "" || raw == "[]" { + return json.RawMessage("[]") + } + var items []struct { + Visibility string `json:"visibility"` + } + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return json.RawMessage("[]") + } + + // Parse full items to preserve all fields + var fullItems []json.RawMessage + if err := json.Unmarshal([]byte(raw), &fullItems); err != nil { + return json.RawMessage("[]") + } + + var filtered []json.RawMessage + for i, item := range items { + if item.Visibility != "admin" { + filtered = append(filtered, fullItems[i]) + } + } + if len(filtered) == 0 { + return json.RawMessage("[]") + } + result, err := json.Marshal(filtered) + if err != nil { + return json.RawMessage("[]") + } + return result +} + +// GetFrameSrcOrigins returns deduplicated http(s) origins from purchase_subscription_url +// and all custom_menu_items URLs. Used by the router layer for CSP frame-src injection. +func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, error) { + settings, err := s.GetPublicSettings(ctx) + if err != nil { + return nil, err + } + + seen := make(map[string]struct{}) + var origins []string + + addOrigin := func(rawURL string) { + if origin := extractOriginFromURL(rawURL); origin != "" { + if _, ok := seen[origin]; !ok { + seen[origin] = struct{}{} + origins = append(origins, origin) + } + } + } + + // purchase subscription URL + if settings.PurchaseSubscriptionEnabled { + addOrigin(settings.PurchaseSubscriptionURL) + } + + // all custom menu items (including admin-only, since CSP must allow all iframes) + for _, item := range parseCustomMenuItemURLs(settings.CustomMenuItems) { + addOrigin(item) + } + + return origins, nil +} + +// extractOriginFromURL returns the scheme+host origin from rawURL. +// Only http and https schemes are accepted. 
+func extractOriginFromURL(rawURL string) string { + rawURL = strings.TrimSpace(rawURL) + if rawURL == "" { + return "" + } + u, err := url.Parse(rawURL) + if err != nil || u.Host == "" { + return "" + } + if u.Scheme != "http" && u.Scheme != "https" { + return "" + } + return u.Scheme + "://" + u.Host +} + +// parseCustomMenuItemURLs extracts URLs from a raw JSON array of custom menu items. +func parseCustomMenuItemURLs(raw string) []string { + raw = strings.TrimSpace(raw) + if raw == "" || raw == "[]" { + return nil + } + var items []struct { + URL string `json:"url"` + } + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return nil + } + urls := make([]string, 0, len(items)) + for _, item := range items { + if item.URL != "" { + urls = append(urls, item.URL) + } + } + return urls +} + // UpdateSettings 更新系统设置 func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error { + if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil { + return err + } + normalizedWhitelist, err := NormalizeRegistrationEmailSuffixWhitelist(settings.RegistrationEmailSuffixWhitelist) + if err != nil { + return infraerrors.BadRequest("INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", err.Error()) + } + if normalizedWhitelist == nil { + normalizedWhitelist = []string{} + } + settings.RegistrationEmailSuffixWhitelist = normalizedWhitelist + updates := make(map[string]string) // 注册设置 updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled) updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled) + registrationEmailSuffixWhitelistJSON, err := json.Marshal(settings.RegistrationEmailSuffixWhitelist) + if err != nil { + return fmt.Errorf("marshal registration email suffix whitelist: %w", err) + } + updates[SettingKeyRegistrationEmailSuffixWhitelist] = string(registrationEmailSuffixWhitelistJSON) updates[SettingKeyPromoCodeEnabled] = 
strconv.FormatBool(settings.PromoCodeEnabled) updates[SettingKeyPasswordResetEnabled] = strconv.FormatBool(settings.PasswordResetEnabled) updates[SettingKeyInvitationCodeEnabled] = strconv.FormatBool(settings.InvitationCodeEnabled) @@ -232,10 +424,17 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton) updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled) updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL) + updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled) + updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems // 默认配置 updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64) + defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions) + if err != nil { + return fmt.Errorf("marshal default subscriptions: %w", err) + } + updates[SettingKeyDefaultSubscriptions] = string(defaultSubsJSON) // Model fallback configuration updates[SettingKeyEnableModelFallback] = strconv.FormatBool(settings.EnableModelFallback) @@ -256,13 +455,66 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyOpsMetricsIntervalSeconds] = strconv.Itoa(settings.OpsMetricsIntervalSeconds) } - err := s.settingRepo.SetMultiple(ctx, updates) - if err == nil && s.onUpdate != nil { - s.onUpdate() // Invalidate cache after settings update + // Claude Code version check + updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion + + // 分组隔离 + updates[SettingKeyAllowUngroupedKeyScheduling] = strconv.FormatBool(settings.AllowUngroupedKeyScheduling) + + err = s.settingRepo.SetMultiple(ctx, updates) + if err == nil { + // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口 
+ minVersionSF.Forget("min_version") + minVersionCache.Store(&cachedMinVersion{ + value: settings.MinClaudeCodeVersion, + expiresAt: time.Now().Add(minVersionCacheTTL).UnixNano(), + }) + if s.onUpdate != nil { + s.onUpdate() // Invalidate cache after settings update + } } return err } +func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context, items []DefaultSubscriptionSetting) error { + if len(items) == 0 { + return nil + } + + checked := make(map[int64]struct{}, len(items)) + for _, item := range items { + if item.GroupID <= 0 { + continue + } + if _, ok := checked[item.GroupID]; ok { + return ErrDefaultSubGroupDuplicate.WithMetadata(map[string]string{ + "group_id": strconv.FormatInt(item.GroupID, 10), + }) + } + checked[item.GroupID] = struct{}{} + if s.defaultSubGroupReader == nil { + continue + } + + group, err := s.defaultSubGroupReader.GetByID(ctx, item.GroupID) + if err != nil { + if errors.Is(err, ErrGroupNotFound) { + return ErrDefaultSubGroupInvalid.WithMetadata(map[string]string{ + "group_id": strconv.FormatInt(item.GroupID, 10), + }) + } + return fmt.Errorf("get default subscription group %d: %w", item.GroupID, err) + } + if !group.IsSubscriptionType() { + return ErrDefaultSubGroupInvalid.WithMetadata(map[string]string{ + "group_id": strconv.FormatInt(item.GroupID, 10), + }) + } + } + + return nil +} + // IsRegistrationEnabled 检查是否开放注册 func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool { value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled) @@ -282,6 +534,15 @@ func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool { return value == "true" } +// GetRegistrationEmailSuffixWhitelist returns normalized registration email suffix whitelist. 
+func (s *SettingService) GetRegistrationEmailSuffixWhitelist(ctx context.Context) []string { + value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEmailSuffixWhitelist) + if err != nil { + return []string{} + } + return ParseRegistrationEmailSuffixWhitelist(value) +} + // IsPromoCodeEnabled 检查是否启用优惠码功能 func (s *SettingService) IsPromoCodeEnabled(ctx context.Context) bool { value, err := s.settingRepo.GetValue(ctx, SettingKeyPromoCodeEnabled) @@ -362,6 +623,15 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 { return s.cfg.Default.UserBalance } +// GetDefaultSubscriptions 获取新用户默认订阅配置列表。 +func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultSubscriptionSetting { + value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultSubscriptions) + if err != nil { + return nil + } + return parseDefaultSubscriptions(value) +} + // InitializeDefaultSettings 初始化默认设置 func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // 检查是否已有设置 @@ -376,17 +646,21 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // 初始化默认设置 defaults := map[string]string{ - SettingKeyRegistrationEnabled: "true", - SettingKeyEmailVerifyEnabled: "false", - SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能 - SettingKeySiteName: "Sub2API", - SettingKeySiteLogo: "", - SettingKeyPurchaseSubscriptionEnabled: "false", - SettingKeyPurchaseSubscriptionURL: "", - SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), - SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), - SettingKeySMTPPort: "587", - SettingKeySMTPUseTLS: "false", + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "false", + SettingKeyRegistrationEmailSuffixWhitelist: "[]", + SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能 + SettingKeySiteName: "Sub2API", + SettingKeySiteLogo: "", + SettingKeyPurchaseSubscriptionEnabled: "false", + 
SettingKeyPurchaseSubscriptionURL: "", + SettingKeySoraClientEnabled: "false", + SettingKeyCustomMenuItems: "[]", + SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), + SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), + SettingKeyDefaultSubscriptions: "[]", + SettingKeySMTPPort: "587", + SettingKeySMTPUseTLS: "false", // Model fallback defaults SettingKeyEnableModelFallback: "false", SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022", @@ -402,6 +676,12 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyOpsRealtimeMonitoringEnabled: "true", SettingKeyOpsQueryModeDefault: "auto", SettingKeyOpsMetricsIntervalSeconds: "60", + + // Claude Code version check (default: empty = disabled) + SettingKeyMinClaudeCodeVersion: "", + + // 分组隔离(默认不允许未分组 Key 调度) + SettingKeyAllowUngroupedKeyScheduling: "false", } return s.settingRepo.SetMultiple(ctx, defaults) @@ -411,31 +691,34 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings { emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true" result := &SystemSettings{ - RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", - EmailVerifyEnabled: emailVerifyEnabled, - PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 - PasswordResetEnabled: emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true", - InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true", - TotpEnabled: settings[SettingKeyTotpEnabled] == "true", - SMTPHost: settings[SettingKeySMTPHost], - SMTPUsername: settings[SettingKeySMTPUsername], - SMTPFrom: settings[SettingKeySMTPFrom], - SMTPFromName: settings[SettingKeySMTPFromName], - SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true", - SMTPPasswordConfigured: settings[SettingKeySMTPPassword] != "", - 
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", - TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], - TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "", - SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"), - SiteLogo: settings[SettingKeySiteLogo], - SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), - APIBaseURL: settings[SettingKeyAPIBaseURL], - ContactInfo: settings[SettingKeyContactInfo], - DocURL: settings[SettingKeyDocURL], - HomeContent: settings[SettingKeyHomeContent], - HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", - PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", - PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), + RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", + EmailVerifyEnabled: emailVerifyEnabled, + RegistrationEmailSuffixWhitelist: ParseRegistrationEmailSuffixWhitelist(settings[SettingKeyRegistrationEmailSuffixWhitelist]), + PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 + PasswordResetEnabled: emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true", + InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true", + TotpEnabled: settings[SettingKeyTotpEnabled] == "true", + SMTPHost: settings[SettingKeySMTPHost], + SMTPUsername: settings[SettingKeySMTPUsername], + SMTPFrom: settings[SettingKeySMTPFrom], + SMTPFromName: settings[SettingKeySMTPFromName], + SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true", + SMTPPasswordConfigured: settings[SettingKeySMTPPassword] != "", + TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", + TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], + TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "", + SiteName: s.getStringOrDefault(settings, SettingKeySiteName, 
"Sub2API"), + SiteLogo: settings[SettingKeySiteLogo], + SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), + APIBaseURL: settings[SettingKeyAPIBaseURL], + ContactInfo: settings[SettingKeyContactInfo], + DocURL: settings[SettingKeyDocURL], + HomeContent: settings[SettingKeyHomeContent], + HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", + PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", + PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), + SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", + CustomMenuItems: settings[SettingKeyCustomMenuItems], } // 解析整数类型 @@ -457,6 +740,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } else { result.DefaultBalance = s.cfg.Default.UserBalance } + result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions]) // 敏感信息直接返回,方便测试连接时使用 result.SMTPPassword = settings[SettingKeySMTPPassword] @@ -526,6 +810,12 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } } + // Claude Code version check + result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion] + + // 分组隔离 + result.AllowUngroupedKeyScheduling = settings[SettingKeyAllowUngroupedKeyScheduling] == "true" + return result } @@ -538,6 +828,31 @@ func isFalseSettingValue(value string) bool { } } +func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + + var items []DefaultSubscriptionSetting + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return nil + } + + normalized := make([]DefaultSubscriptionSetting, 0, len(items)) + for _, item := range items { + if item.GroupID <= 0 || item.ValidityDays <= 0 { + continue + } + if item.ValidityDays > MaxValidityDays { + item.ValidityDays = MaxValidityDays + 
} + normalized = append(normalized, item) + } + + return normalized +} + // getStringOrDefault 获取字符串值或默认值 func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string { if value, ok := settings[key]; ok && value != "" { @@ -823,6 +1138,62 @@ func (s *SettingService) GetStreamTimeoutSettings(ctx context.Context) (*StreamT return &settings, nil } +// IsUngroupedKeySchedulingAllowed 查询是否允许未分组 Key 调度 +func (s *SettingService) IsUngroupedKeySchedulingAllowed(ctx context.Context) bool { + value, err := s.settingRepo.GetValue(ctx, SettingKeyAllowUngroupedKeyScheduling) + if err != nil { + return false // fail-closed: 查询失败时默认不允许 + } + return value == "true" +} + +// GetMinClaudeCodeVersion 获取最低 Claude Code 版本号要求 +// 使用进程内 atomic.Value 缓存,60 秒 TTL,热路径零锁开销 +// singleflight 防止缓存过期时 thundering herd +// 返回空字符串表示不做版本检查 +func (s *SettingService) GetMinClaudeCodeVersion(ctx context.Context) string { + if cached, ok := minVersionCache.Load().(*cachedMinVersion); ok { + if time.Now().UnixNano() < cached.expiresAt { + return cached.value + } + } + // singleflight: 同一时刻只有一个 goroutine 查询 DB,其余复用结果 + result, err, _ := minVersionSF.Do("min_version", func() (any, error) { + // 二次检查,避免排队的 goroutine 重复查询 + if cached, ok := minVersionCache.Load().(*cachedMinVersion); ok { + if time.Now().UnixNano() < cached.expiresAt { + return cached.value, nil + } + } + // 使用独立 context:断开请求取消链,避免客户端断连导致空值被长期缓存 + dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), minVersionDBTimeout) + defer cancel() + value, err := s.settingRepo.GetValue(dbCtx, SettingKeyMinClaudeCodeVersion) + if err != nil { + // fail-open: DB 错误时不阻塞请求,但记录日志并使用短 TTL 快速重试 + slog.Warn("failed to get min claude code version setting, skipping version check", "error", err) + minVersionCache.Store(&cachedMinVersion{ + value: "", + expiresAt: time.Now().Add(minVersionErrorTTL).UnixNano(), + }) + return "", nil + } + minVersionCache.Store(&cachedMinVersion{ + value: value, + expiresAt: 
time.Now().Add(minVersionCacheTTL).UnixNano(), + }) + return value, nil + }) + if err != nil { + return "" + } + ver, ok := result.(string) + if !ok { + return "" + } + return ver +} + // SetStreamTimeoutSettings 设置流超时处理配置 func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error { if settings == nil { @@ -854,3 +1225,607 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings return s.settingRepo.Set(ctx, SettingKeyStreamTimeoutSettings, string(data)) } + +type soraS3ProfilesStore struct { + ActiveProfileID string `json:"active_profile_id"` + Items []soraS3ProfileStoreItem `json:"items"` +} + +type soraS3ProfileStoreItem struct { + ProfileID string `json:"profile_id"` + Name string `json:"name"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` + UpdatedAt string `json:"updated_at"` +} + +// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置语义:返回当前激活配置) +func (s *SettingService) GetSoraS3Settings(ctx context.Context) (*SoraS3Settings, error) { + profiles, err := s.ListSoraS3Profiles(ctx) + if err != nil { + return nil, err + } + + activeProfile := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID) + if activeProfile == nil { + return &SoraS3Settings{}, nil + } + + return &SoraS3Settings{ + Enabled: activeProfile.Enabled, + Endpoint: activeProfile.Endpoint, + Region: activeProfile.Region, + Bucket: activeProfile.Bucket, + AccessKeyID: activeProfile.AccessKeyID, + SecretAccessKey: activeProfile.SecretAccessKey, + SecretAccessKeyConfigured: activeProfile.SecretAccessKeyConfigured, + Prefix: activeProfile.Prefix, + ForcePathStyle: 
activeProfile.ForcePathStyle, + CDNURL: activeProfile.CDNURL, + DefaultStorageQuotaBytes: activeProfile.DefaultStorageQuotaBytes, + }, nil +} + +// SetSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置语义:写入当前激活配置) +func (s *SettingService) SetSoraS3Settings(ctx context.Context, settings *SoraS3Settings) error { + if settings == nil { + return fmt.Errorf("settings cannot be nil") + } + + store, err := s.loadSoraS3ProfilesStore(ctx) + if err != nil { + return err + } + + now := time.Now().UTC().Format(time.RFC3339) + activeIndex := findSoraS3ProfileIndex(store.Items, store.ActiveProfileID) + if activeIndex < 0 { + activeID := "default" + if hasSoraS3ProfileID(store.Items, activeID) { + activeID = fmt.Sprintf("default-%d", time.Now().Unix()) + } + store.Items = append(store.Items, soraS3ProfileStoreItem{ + ProfileID: activeID, + Name: "Default", + UpdatedAt: now, + }) + store.ActiveProfileID = activeID + activeIndex = len(store.Items) - 1 + } + + active := store.Items[activeIndex] + active.Enabled = settings.Enabled + active.Endpoint = strings.TrimSpace(settings.Endpoint) + active.Region = strings.TrimSpace(settings.Region) + active.Bucket = strings.TrimSpace(settings.Bucket) + active.AccessKeyID = strings.TrimSpace(settings.AccessKeyID) + active.Prefix = strings.TrimSpace(settings.Prefix) + active.ForcePathStyle = settings.ForcePathStyle + active.CDNURL = strings.TrimSpace(settings.CDNURL) + active.DefaultStorageQuotaBytes = maxInt64(settings.DefaultStorageQuotaBytes, 0) + if settings.SecretAccessKey != "" { + active.SecretAccessKey = settings.SecretAccessKey + } + active.UpdatedAt = now + store.Items[activeIndex] = active + + return s.persistSoraS3ProfilesStore(ctx, store) +} + +// ListSoraS3Profiles 获取 Sora S3 多配置列表 +func (s *SettingService) ListSoraS3Profiles(ctx context.Context) (*SoraS3ProfileList, error) { + store, err := s.loadSoraS3ProfilesStore(ctx) + if err != nil { + return nil, err + } + return convertSoraS3ProfilesStore(store), nil +} + +// CreateSoraS3Profile 创建 
Sora S3 配置 +func (s *SettingService) CreateSoraS3Profile(ctx context.Context, profile *SoraS3Profile, setActive bool) (*SoraS3Profile, error) { + if profile == nil { + return nil, fmt.Errorf("profile cannot be nil") + } + + profileID := strings.TrimSpace(profile.ProfileID) + if profileID == "" { + return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required") + } + name := strings.TrimSpace(profile.Name) + if name == "" { + return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required") + } + + store, err := s.loadSoraS3ProfilesStore(ctx) + if err != nil { + return nil, err + } + if hasSoraS3ProfileID(store.Items, profileID) { + return nil, ErrSoraS3ProfileExists + } + + now := time.Now().UTC().Format(time.RFC3339) + store.Items = append(store.Items, soraS3ProfileStoreItem{ + ProfileID: profileID, + Name: name, + Enabled: profile.Enabled, + Endpoint: strings.TrimSpace(profile.Endpoint), + Region: strings.TrimSpace(profile.Region), + Bucket: strings.TrimSpace(profile.Bucket), + AccessKeyID: strings.TrimSpace(profile.AccessKeyID), + SecretAccessKey: profile.SecretAccessKey, + Prefix: strings.TrimSpace(profile.Prefix), + ForcePathStyle: profile.ForcePathStyle, + CDNURL: strings.TrimSpace(profile.CDNURL), + DefaultStorageQuotaBytes: maxInt64(profile.DefaultStorageQuotaBytes, 0), + UpdatedAt: now, + }) + + if setActive || store.ActiveProfileID == "" { + store.ActiveProfileID = profileID + } + + if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil { + return nil, err + } + + profiles := convertSoraS3ProfilesStore(store) + created := findSoraS3ProfileByID(profiles.Items, profileID) + if created == nil { + return nil, ErrSoraS3ProfileNotFound + } + return created, nil +} + +// UpdateSoraS3Profile 更新 Sora S3 配置 +func (s *SettingService) UpdateSoraS3Profile(ctx context.Context, profileID string, profile *SoraS3Profile) (*SoraS3Profile, error) { + if profile == nil { + return nil, fmt.Errorf("profile cannot be 
nil") + } + + targetID := strings.TrimSpace(profileID) + if targetID == "" { + return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required") + } + + store, err := s.loadSoraS3ProfilesStore(ctx) + if err != nil { + return nil, err + } + + targetIndex := findSoraS3ProfileIndex(store.Items, targetID) + if targetIndex < 0 { + return nil, ErrSoraS3ProfileNotFound + } + + target := store.Items[targetIndex] + name := strings.TrimSpace(profile.Name) + if name == "" { + return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required") + } + target.Name = name + target.Enabled = profile.Enabled + target.Endpoint = strings.TrimSpace(profile.Endpoint) + target.Region = strings.TrimSpace(profile.Region) + target.Bucket = strings.TrimSpace(profile.Bucket) + target.AccessKeyID = strings.TrimSpace(profile.AccessKeyID) + target.Prefix = strings.TrimSpace(profile.Prefix) + target.ForcePathStyle = profile.ForcePathStyle + target.CDNURL = strings.TrimSpace(profile.CDNURL) + target.DefaultStorageQuotaBytes = maxInt64(profile.DefaultStorageQuotaBytes, 0) + if profile.SecretAccessKey != "" { + target.SecretAccessKey = profile.SecretAccessKey + } + target.UpdatedAt = time.Now().UTC().Format(time.RFC3339) + store.Items[targetIndex] = target + + if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil { + return nil, err + } + + profiles := convertSoraS3ProfilesStore(store) + updated := findSoraS3ProfileByID(profiles.Items, targetID) + if updated == nil { + return nil, ErrSoraS3ProfileNotFound + } + return updated, nil +} + +// DeleteSoraS3Profile 删除 Sora S3 配置 +func (s *SettingService) DeleteSoraS3Profile(ctx context.Context, profileID string) error { + targetID := strings.TrimSpace(profileID) + if targetID == "" { + return infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required") + } + + store, err := s.loadSoraS3ProfilesStore(ctx) + if err != nil { + return err + } + + targetIndex := 
findSoraS3ProfileIndex(store.Items, targetID) + if targetIndex < 0 { + return ErrSoraS3ProfileNotFound + } + + store.Items = append(store.Items[:targetIndex], store.Items[targetIndex+1:]...) + if store.ActiveProfileID == targetID { + store.ActiveProfileID = "" + if len(store.Items) > 0 { + store.ActiveProfileID = store.Items[0].ProfileID + } + } + + return s.persistSoraS3ProfilesStore(ctx, store) +} + +// SetActiveSoraS3Profile 设置激活的 Sora S3 配置 +func (s *SettingService) SetActiveSoraS3Profile(ctx context.Context, profileID string) (*SoraS3Profile, error) { + targetID := strings.TrimSpace(profileID) + if targetID == "" { + return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required") + } + + store, err := s.loadSoraS3ProfilesStore(ctx) + if err != nil { + return nil, err + } + + targetIndex := findSoraS3ProfileIndex(store.Items, targetID) + if targetIndex < 0 { + return nil, ErrSoraS3ProfileNotFound + } + + store.ActiveProfileID = targetID + store.Items[targetIndex].UpdatedAt = time.Now().UTC().Format(time.RFC3339) + if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil { + return nil, err + } + + profiles := convertSoraS3ProfilesStore(store) + active := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID) + if active == nil { + return nil, ErrSoraS3ProfileNotFound + } + return active, nil +} + +func (s *SettingService) loadSoraS3ProfilesStore(ctx context.Context) (*soraS3ProfilesStore, error) { + raw, err := s.settingRepo.GetValue(ctx, SettingKeySoraS3Profiles) + if err == nil { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return &soraS3ProfilesStore{}, nil + } + var store soraS3ProfilesStore + if unmarshalErr := json.Unmarshal([]byte(trimmed), &store); unmarshalErr != nil { + legacy, legacyErr := s.getLegacySoraS3Settings(ctx) + if legacyErr != nil { + return nil, fmt.Errorf("unmarshal sora s3 profiles: %w", unmarshalErr) + } + if isEmptyLegacySoraS3Settings(legacy) { + return 
&soraS3ProfilesStore{}, nil + } + now := time.Now().UTC().Format(time.RFC3339) + return &soraS3ProfilesStore{ + ActiveProfileID: "default", + Items: []soraS3ProfileStoreItem{ + { + ProfileID: "default", + Name: "Default", + Enabled: legacy.Enabled, + Endpoint: strings.TrimSpace(legacy.Endpoint), + Region: strings.TrimSpace(legacy.Region), + Bucket: strings.TrimSpace(legacy.Bucket), + AccessKeyID: strings.TrimSpace(legacy.AccessKeyID), + SecretAccessKey: legacy.SecretAccessKey, + Prefix: strings.TrimSpace(legacy.Prefix), + ForcePathStyle: legacy.ForcePathStyle, + CDNURL: strings.TrimSpace(legacy.CDNURL), + DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0), + UpdatedAt: now, + }, + }, + }, nil + } + normalized := normalizeSoraS3ProfilesStore(store) + return &normalized, nil + } + + if !errors.Is(err, ErrSettingNotFound) { + return nil, fmt.Errorf("get sora s3 profiles: %w", err) + } + + legacy, legacyErr := s.getLegacySoraS3Settings(ctx) + if legacyErr != nil { + return nil, legacyErr + } + if isEmptyLegacySoraS3Settings(legacy) { + return &soraS3ProfilesStore{}, nil + } + + now := time.Now().UTC().Format(time.RFC3339) + return &soraS3ProfilesStore{ + ActiveProfileID: "default", + Items: []soraS3ProfileStoreItem{ + { + ProfileID: "default", + Name: "Default", + Enabled: legacy.Enabled, + Endpoint: strings.TrimSpace(legacy.Endpoint), + Region: strings.TrimSpace(legacy.Region), + Bucket: strings.TrimSpace(legacy.Bucket), + AccessKeyID: strings.TrimSpace(legacy.AccessKeyID), + SecretAccessKey: legacy.SecretAccessKey, + Prefix: strings.TrimSpace(legacy.Prefix), + ForcePathStyle: legacy.ForcePathStyle, + CDNURL: strings.TrimSpace(legacy.CDNURL), + DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0), + UpdatedAt: now, + }, + }, + }, nil +} + +func (s *SettingService) persistSoraS3ProfilesStore(ctx context.Context, store *soraS3ProfilesStore) error { + if store == nil { + return fmt.Errorf("sora s3 profiles store cannot be nil") + } + 
+ normalized := normalizeSoraS3ProfilesStore(*store) + data, err := json.Marshal(normalized) + if err != nil { + return fmt.Errorf("marshal sora s3 profiles: %w", err) + } + + updates := map[string]string{ + SettingKeySoraS3Profiles: string(data), + } + + active := pickActiveSoraS3ProfileFromStore(normalized.Items, normalized.ActiveProfileID) + if active == nil { + updates[SettingKeySoraS3Enabled] = "false" + updates[SettingKeySoraS3Endpoint] = "" + updates[SettingKeySoraS3Region] = "" + updates[SettingKeySoraS3Bucket] = "" + updates[SettingKeySoraS3AccessKeyID] = "" + updates[SettingKeySoraS3Prefix] = "" + updates[SettingKeySoraS3ForcePathStyle] = "false" + updates[SettingKeySoraS3CDNURL] = "" + updates[SettingKeySoraDefaultStorageQuotaBytes] = "0" + updates[SettingKeySoraS3SecretAccessKey] = "" + } else { + updates[SettingKeySoraS3Enabled] = strconv.FormatBool(active.Enabled) + updates[SettingKeySoraS3Endpoint] = strings.TrimSpace(active.Endpoint) + updates[SettingKeySoraS3Region] = strings.TrimSpace(active.Region) + updates[SettingKeySoraS3Bucket] = strings.TrimSpace(active.Bucket) + updates[SettingKeySoraS3AccessKeyID] = strings.TrimSpace(active.AccessKeyID) + updates[SettingKeySoraS3Prefix] = strings.TrimSpace(active.Prefix) + updates[SettingKeySoraS3ForcePathStyle] = strconv.FormatBool(active.ForcePathStyle) + updates[SettingKeySoraS3CDNURL] = strings.TrimSpace(active.CDNURL) + updates[SettingKeySoraDefaultStorageQuotaBytes] = strconv.FormatInt(maxInt64(active.DefaultStorageQuotaBytes, 0), 10) + updates[SettingKeySoraS3SecretAccessKey] = active.SecretAccessKey + } + + if err := s.settingRepo.SetMultiple(ctx, updates); err != nil { + return err + } + + if s.onUpdate != nil { + s.onUpdate() + } + if s.onS3Update != nil { + s.onS3Update() + } + return nil +} + +func (s *SettingService) getLegacySoraS3Settings(ctx context.Context) (*SoraS3Settings, error) { + keys := []string{ + SettingKeySoraS3Enabled, + SettingKeySoraS3Endpoint, + SettingKeySoraS3Region, + 
SettingKeySoraS3Bucket, + SettingKeySoraS3AccessKeyID, + SettingKeySoraS3SecretAccessKey, + SettingKeySoraS3Prefix, + SettingKeySoraS3ForcePathStyle, + SettingKeySoraS3CDNURL, + SettingKeySoraDefaultStorageQuotaBytes, + } + + values, err := s.settingRepo.GetMultiple(ctx, keys) + if err != nil { + return nil, fmt.Errorf("get legacy sora s3 settings: %w", err) + } + + result := &SoraS3Settings{ + Enabled: values[SettingKeySoraS3Enabled] == "true", + Endpoint: values[SettingKeySoraS3Endpoint], + Region: values[SettingKeySoraS3Region], + Bucket: values[SettingKeySoraS3Bucket], + AccessKeyID: values[SettingKeySoraS3AccessKeyID], + SecretAccessKey: values[SettingKeySoraS3SecretAccessKey], + SecretAccessKeyConfigured: values[SettingKeySoraS3SecretAccessKey] != "", + Prefix: values[SettingKeySoraS3Prefix], + ForcePathStyle: values[SettingKeySoraS3ForcePathStyle] == "true", + CDNURL: values[SettingKeySoraS3CDNURL], + } + if v, parseErr := strconv.ParseInt(values[SettingKeySoraDefaultStorageQuotaBytes], 10, 64); parseErr == nil { + result.DefaultStorageQuotaBytes = v + } + return result, nil +} + +func normalizeSoraS3ProfilesStore(store soraS3ProfilesStore) soraS3ProfilesStore { + seen := make(map[string]struct{}, len(store.Items)) + normalized := soraS3ProfilesStore{ + ActiveProfileID: strings.TrimSpace(store.ActiveProfileID), + Items: make([]soraS3ProfileStoreItem, 0, len(store.Items)), + } + now := time.Now().UTC().Format(time.RFC3339) + + for idx := range store.Items { + item := store.Items[idx] + item.ProfileID = strings.TrimSpace(item.ProfileID) + if item.ProfileID == "" { + item.ProfileID = fmt.Sprintf("profile-%d", idx+1) + } + if _, exists := seen[item.ProfileID]; exists { + continue + } + seen[item.ProfileID] = struct{}{} + + item.Name = strings.TrimSpace(item.Name) + if item.Name == "" { + item.Name = item.ProfileID + } + item.Endpoint = strings.TrimSpace(item.Endpoint) + item.Region = strings.TrimSpace(item.Region) + item.Bucket = strings.TrimSpace(item.Bucket) + 
item.AccessKeyID = strings.TrimSpace(item.AccessKeyID) + item.Prefix = strings.TrimSpace(item.Prefix) + item.CDNURL = strings.TrimSpace(item.CDNURL) + item.DefaultStorageQuotaBytes = maxInt64(item.DefaultStorageQuotaBytes, 0) + item.UpdatedAt = strings.TrimSpace(item.UpdatedAt) + if item.UpdatedAt == "" { + item.UpdatedAt = now + } + normalized.Items = append(normalized.Items, item) + } + + if len(normalized.Items) == 0 { + normalized.ActiveProfileID = "" + return normalized + } + + if findSoraS3ProfileIndex(normalized.Items, normalized.ActiveProfileID) >= 0 { + return normalized + } + + normalized.ActiveProfileID = normalized.Items[0].ProfileID + return normalized +} + +func convertSoraS3ProfilesStore(store *soraS3ProfilesStore) *SoraS3ProfileList { + if store == nil { + return &SoraS3ProfileList{} + } + items := make([]SoraS3Profile, 0, len(store.Items)) + for idx := range store.Items { + item := store.Items[idx] + items = append(items, SoraS3Profile{ + ProfileID: item.ProfileID, + Name: item.Name, + IsActive: item.ProfileID == store.ActiveProfileID, + Enabled: item.Enabled, + Endpoint: item.Endpoint, + Region: item.Region, + Bucket: item.Bucket, + AccessKeyID: item.AccessKeyID, + SecretAccessKey: item.SecretAccessKey, + SecretAccessKeyConfigured: item.SecretAccessKey != "", + Prefix: item.Prefix, + ForcePathStyle: item.ForcePathStyle, + CDNURL: item.CDNURL, + DefaultStorageQuotaBytes: item.DefaultStorageQuotaBytes, + UpdatedAt: item.UpdatedAt, + }) + } + return &SoraS3ProfileList{ + ActiveProfileID: store.ActiveProfileID, + Items: items, + } +} + +func pickActiveSoraS3Profile(items []SoraS3Profile, activeProfileID string) *SoraS3Profile { + for idx := range items { + if items[idx].ProfileID == activeProfileID { + return &items[idx] + } + } + if len(items) == 0 { + return nil + } + return &items[0] +} + +func findSoraS3ProfileByID(items []SoraS3Profile, profileID string) *SoraS3Profile { + for idx := range items { + if items[idx].ProfileID == profileID { + return 
&items[idx] + } + } + return nil +} + +func pickActiveSoraS3ProfileFromStore(items []soraS3ProfileStoreItem, activeProfileID string) *soraS3ProfileStoreItem { + for idx := range items { + if items[idx].ProfileID == activeProfileID { + return &items[idx] + } + } + if len(items) == 0 { + return nil + } + return &items[0] +} + +func findSoraS3ProfileIndex(items []soraS3ProfileStoreItem, profileID string) int { + for idx := range items { + if items[idx].ProfileID == profileID { + return idx + } + } + return -1 +} + +func hasSoraS3ProfileID(items []soraS3ProfileStoreItem, profileID string) bool { + return findSoraS3ProfileIndex(items, profileID) >= 0 +} + +func isEmptyLegacySoraS3Settings(settings *SoraS3Settings) bool { + if settings == nil { + return true + } + if settings.Enabled { + return false + } + if strings.TrimSpace(settings.Endpoint) != "" { + return false + } + if strings.TrimSpace(settings.Region) != "" { + return false + } + if strings.TrimSpace(settings.Bucket) != "" { + return false + } + if strings.TrimSpace(settings.AccessKeyID) != "" { + return false + } + if settings.SecretAccessKey != "" { + return false + } + if strings.TrimSpace(settings.Prefix) != "" { + return false + } + if strings.TrimSpace(settings.CDNURL) != "" { + return false + } + return settings.DefaultStorageQuotaBytes == 0 +} + +func maxInt64(value int64, min int64) int64 { + if value < min { + return min + } + return value +} diff --git a/backend/internal/service/setting_service_public_test.go b/backend/internal/service/setting_service_public_test.go new file mode 100644 index 00000000..b511cd29 --- /dev/null +++ b/backend/internal/service/setting_service_public_test.go @@ -0,0 +1,64 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type settingPublicRepoStub struct { + values map[string]string +} + +func (s *settingPublicRepoStub) Get(ctx context.Context, key 
string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *settingPublicRepoStub) GetValue(ctx context.Context, key string) (string, error) { + panic("unexpected GetValue call") +} + +func (s *settingPublicRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *settingPublicRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *settingPublicRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *settingPublicRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *settingPublicRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +func TestSettingService_GetPublicSettings_ExposesRegistrationEmailSuffixWhitelist(t *testing.T) { + repo := &settingPublicRepoStub{ + values: map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "true", + SettingKeyRegistrationEmailSuffixWhitelist: `["@EXAMPLE.com"," @foo.bar ","@invalid_domain",""]`, + }, + } + svc := NewSettingService(repo, &config.Config{}) + + settings, err := svc.GetPublicSettings(context.Background()) + require.NoError(t, err) + require.Equal(t, []string{"@example.com", "@foo.bar"}, settings.RegistrationEmailSuffixWhitelist) +} diff --git a/backend/internal/service/setting_service_update_test.go b/backend/internal/service/setting_service_update_test.go new file mode 100644 index 00000000..1de08611 --- /dev/null +++ b/backend/internal/service/setting_service_update_test.go @@ -0,0 +1,204 @@ +//go:build unit + +package service + +import ( + "context" + "encoding/json" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + 
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +type settingUpdateRepoStub struct { + updates map[string]string +} + +func (s *settingUpdateRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *settingUpdateRepoStub) GetValue(ctx context.Context, key string) (string, error) { + panic("unexpected GetValue call") +} + +func (s *settingUpdateRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *settingUpdateRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + panic("unexpected GetMultiple call") +} + +func (s *settingUpdateRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + s.updates = make(map[string]string, len(settings)) + for k, v := range settings { + s.updates[k] = v + } + return nil +} + +func (s *settingUpdateRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *settingUpdateRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +type defaultSubGroupReaderStub struct { + byID map[int64]*Group + errBy map[int64]error + calls []int64 +} + +func (s *defaultSubGroupReaderStub) GetByID(ctx context.Context, id int64) (*Group, error) { + s.calls = append(s.calls, id) + if err, ok := s.errBy[id]; ok { + return nil, err + } + if g, ok := s.byID[id]; ok { + return g, nil + } + return nil, ErrGroupNotFound +} + +func TestSettingService_UpdateSettings_DefaultSubscriptions_ValidGroup(t *testing.T) { + repo := &settingUpdateRepoStub{} + groupReader := &defaultSubGroupReaderStub{ + byID: map[int64]*Group{ + 11: {ID: 11, SubscriptionType: SubscriptionTypeSubscription}, + }, + } + svc := NewSettingService(repo, &config.Config{}) + svc.SetDefaultSubscriptionGroupReader(groupReader) + + err := svc.UpdateSettings(context.Background(), 
&SystemSettings{ + DefaultSubscriptions: []DefaultSubscriptionSetting{ + {GroupID: 11, ValidityDays: 30}, + }, + }) + require.NoError(t, err) + require.Equal(t, []int64{11}, groupReader.calls) + + raw, ok := repo.updates[SettingKeyDefaultSubscriptions] + require.True(t, ok) + + var got []DefaultSubscriptionSetting + require.NoError(t, json.Unmarshal([]byte(raw), &got)) + require.Equal(t, []DefaultSubscriptionSetting{ + {GroupID: 11, ValidityDays: 30}, + }, got) +} + +func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsNonSubscriptionGroup(t *testing.T) { + repo := &settingUpdateRepoStub{} + groupReader := &defaultSubGroupReaderStub{ + byID: map[int64]*Group{ + 12: {ID: 12, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := NewSettingService(repo, &config.Config{}) + svc.SetDefaultSubscriptionGroupReader(groupReader) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + DefaultSubscriptions: []DefaultSubscriptionSetting{ + {GroupID: 12, ValidityDays: 7}, + }, + }) + require.Error(t, err) + require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_INVALID", infraerrors.Reason(err)) + require.Nil(t, repo.updates) +} + +func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsNotFoundGroup(t *testing.T) { + repo := &settingUpdateRepoStub{} + groupReader := &defaultSubGroupReaderStub{ + errBy: map[int64]error{ + 13: ErrGroupNotFound, + }, + } + svc := NewSettingService(repo, &config.Config{}) + svc.SetDefaultSubscriptionGroupReader(groupReader) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + DefaultSubscriptions: []DefaultSubscriptionSetting{ + {GroupID: 13, ValidityDays: 7}, + }, + }) + require.Error(t, err) + require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_INVALID", infraerrors.Reason(err)) + require.Equal(t, "13", infraerrors.FromError(err).Metadata["group_id"]) + require.Nil(t, repo.updates) +} + +func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsDuplicateGroup(t *testing.T) { + repo := 
&settingUpdateRepoStub{} + groupReader := &defaultSubGroupReaderStub{ + byID: map[int64]*Group{ + 11: {ID: 11, SubscriptionType: SubscriptionTypeSubscription}, + }, + } + svc := NewSettingService(repo, &config.Config{}) + svc.SetDefaultSubscriptionGroupReader(groupReader) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + DefaultSubscriptions: []DefaultSubscriptionSetting{ + {GroupID: 11, ValidityDays: 30}, + {GroupID: 11, ValidityDays: 60}, + }, + }) + require.Error(t, err) + require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE", infraerrors.Reason(err)) + require.Equal(t, "11", infraerrors.FromError(err).Metadata["group_id"]) + require.Nil(t, repo.updates) +} + +func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsDuplicateGroupWithoutGroupReader(t *testing.T) { + repo := &settingUpdateRepoStub{} + svc := NewSettingService(repo, &config.Config{}) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + DefaultSubscriptions: []DefaultSubscriptionSetting{ + {GroupID: 11, ValidityDays: 30}, + {GroupID: 11, ValidityDays: 60}, + }, + }) + require.Error(t, err) + require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE", infraerrors.Reason(err)) + require.Equal(t, "11", infraerrors.FromError(err).Metadata["group_id"]) + require.Nil(t, repo.updates) +} + +func TestSettingService_UpdateSettings_RegistrationEmailSuffixWhitelist_Normalized(t *testing.T) { + repo := &settingUpdateRepoStub{} + svc := NewSettingService(repo, &config.Config{}) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + RegistrationEmailSuffixWhitelist: []string{"example.com", "@EXAMPLE.com", " @foo.bar "}, + }) + require.NoError(t, err) + require.Equal(t, `["@example.com","@foo.bar"]`, repo.updates[SettingKeyRegistrationEmailSuffixWhitelist]) +} + +func TestSettingService_UpdateSettings_RegistrationEmailSuffixWhitelist_Invalid(t *testing.T) { + repo := &settingUpdateRepoStub{} + svc := NewSettingService(repo, &config.Config{}) + + err 
:= svc.UpdateSettings(context.Background(), &SystemSettings{ + RegistrationEmailSuffixWhitelist: []string{"@invalid_domain"}, + }) + require.Error(t, err) + require.Equal(t, "INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", infraerrors.Reason(err)) +} + +func TestParseDefaultSubscriptions_NormalizesValues(t *testing.T) { + got := parseDefaultSubscriptions(`[{"group_id":11,"validity_days":30},{"group_id":11,"validity_days":60},{"group_id":0,"validity_days":10},{"group_id":12,"validity_days":99999}]`) + require.Equal(t, []DefaultSubscriptionSetting{ + {GroupID: 11, ValidityDays: 30}, + {GroupID: 11, ValidityDays: 60}, + {GroupID: 12, ValidityDays: MaxValidityDays}, + }, got) +} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 0c7bab67..6a1d62d8 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -1,12 +1,13 @@ package service type SystemSettings struct { - RegistrationEnabled bool - EmailVerifyEnabled bool - PromoCodeEnabled bool - PasswordResetEnabled bool - InvitationCodeEnabled bool - TotpEnabled bool // TOTP 双因素认证 + RegistrationEnabled bool + EmailVerifyEnabled bool + RegistrationEmailSuffixWhitelist []string + PromoCodeEnabled bool + PasswordResetEnabled bool + InvitationCodeEnabled bool + TotpEnabled bool // TOTP 双因素认证 SMTPHost string SMTPPort int @@ -39,9 +40,12 @@ type SystemSettings struct { HideCcsImportButton bool PurchaseSubscriptionEnabled bool PurchaseSubscriptionURL string + SoraClientEnabled bool + CustomMenuItems string // JSON array of custom menu items - DefaultConcurrency int - DefaultBalance float64 + DefaultConcurrency int + DefaultBalance float64 + DefaultSubscriptions []DefaultSubscriptionSetting // Model fallback configuration EnableModelFallback bool `json:"enable_model_fallback"` @@ -59,33 +63,87 @@ type SystemSettings struct { OpsRealtimeMonitoringEnabled bool OpsQueryModeDefault string OpsMetricsIntervalSeconds int + + // Claude Code 
version check + MinClaudeCodeVersion string + + // 分组隔离:允许未分组 Key 调度(默认 false → 403) + AllowUngroupedKeyScheduling bool +} + +type DefaultSubscriptionSetting struct { + GroupID int64 `json:"group_id"` + ValidityDays int `json:"validity_days"` } type PublicSettings struct { - RegistrationEnabled bool - EmailVerifyEnabled bool - PromoCodeEnabled bool - PasswordResetEnabled bool - InvitationCodeEnabled bool - TotpEnabled bool // TOTP 双因素认证 - TurnstileEnabled bool - TurnstileSiteKey string - SiteName string - SiteLogo string - SiteSubtitle string - APIBaseURL string - ContactInfo string - DocURL string - HomeContent string - HideCcsImportButton bool + RegistrationEnabled bool + EmailVerifyEnabled bool + RegistrationEmailSuffixWhitelist []string + PromoCodeEnabled bool + PasswordResetEnabled bool + InvitationCodeEnabled bool + TotpEnabled bool // TOTP 双因素认证 + TurnstileEnabled bool + TurnstileSiteKey string + SiteName string + SiteLogo string + SiteSubtitle string + APIBaseURL string + ContactInfo string + DocURL string + HomeContent string + HideCcsImportButton bool PurchaseSubscriptionEnabled bool PurchaseSubscriptionURL string + SoraClientEnabled bool + CustomMenuItems string // JSON array of custom menu items LinuxDoOAuthEnabled bool Version string } +// SoraS3Settings Sora S3 存储配置 +type SoraS3Settings struct { + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` // 仅内部使用,不直接返回前端 + SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用 + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` +} + +// SoraS3Profile Sora S3 多配置项(服务内部模型) +type SoraS3Profile struct { + ProfileID string `json:"profile_id"` + Name string `json:"name"` + IsActive bool 
`json:"is_active"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"-"` // 仅内部使用,不直接返回前端 + SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用 + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` + UpdatedAt string `json:"updated_at"` +} + +// SoraS3ProfileList Sora S3 多配置列表 +type SoraS3ProfileList struct { + ActiveProfileID string `json:"active_profile_id"` + Items []SoraS3Profile `json:"items"` +} + // StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制) type StreamTimeoutSettings struct { // Enabled 是否启用流超时处理 diff --git a/backend/internal/service/sora_account_service.go b/backend/internal/service/sora_account_service.go new file mode 100644 index 00000000..eccc1acf --- /dev/null +++ b/backend/internal/service/sora_account_service.go @@ -0,0 +1,40 @@ +package service + +import "context" + +// SoraAccountRepository Sora 账号扩展表仓储接口 +// 用于管理 sora_accounts 表,与 accounts 主表形成双表结构。 +// +// 设计说明: +// - sora_accounts 表存储 Sora 账号的 OAuth 凭证副本 +// - Sora gateway 优先读取此表的字段以获得更好的查询性能 +// - 主表 accounts 通过 credentials JSON 字段也存储相同信息 +// - Token 刷新时需要同时更新两个表以保持数据一致性 +type SoraAccountRepository interface { + // Upsert 创建或更新 Sora 账号扩展信息 + // accountID: 关联的 accounts.id + // updates: 要更新的字段,支持 access_token、refresh_token、session_token + // + // 如果记录不存在则创建,存在则更新。 + // 用于: + // 1. 创建 Sora 账号时初始化扩展表 + // 2. 
Token 刷新时同步更新扩展表 + Upsert(ctx context.Context, accountID int64, updates map[string]any) error + + // GetByAccountID 根据账号 ID 获取 Sora 扩展信息 + // 返回 nil, nil 表示记录不存在(非错误) + GetByAccountID(ctx context.Context, accountID int64) (*SoraAccount, error) + + // Delete 删除 Sora 账号扩展信息 + // 通常由外键 ON DELETE CASCADE 自动处理,此方法用于手动清理 + Delete(ctx context.Context, accountID int64) error +} + +// SoraAccount Sora 账号扩展信息 +// 对应 sora_accounts 表,存储 Sora 账号的 OAuth 凭证副本 +type SoraAccount struct { + AccountID int64 // 关联的 accounts.id + AccessToken string // OAuth access_token + RefreshToken string // OAuth refresh_token + SessionToken string // Session token(可选,用于 ST→AT 兜底) +} diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go new file mode 100644 index 00000000..0a914d2d --- /dev/null +++ b/backend/internal/service/sora_client.go @@ -0,0 +1,117 @@ +package service + +import ( + "context" + "fmt" + "net/http" +) + +// SoraClient 定义直连 Sora 的任务操作接口。 +type SoraClient interface { + Enabled() bool + UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) + CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) + CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) + CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) + UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) + GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) + DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) + UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) + FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) + SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error + 
DeleteCharacter(ctx context.Context, account *Account, characterID string) error + PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) + DeletePost(ctx context.Context, account *Account, postID string) error + GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) + EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) + GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) + GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) +} + +// SoraImageRequest 图片生成请求参数 +type SoraImageRequest struct { + Prompt string + Width int + Height int + MediaID string +} + +// SoraVideoRequest 视频生成请求参数 +type SoraVideoRequest struct { + Prompt string + Orientation string + Frames int + Model string + Size string + VideoCount int + MediaID string + RemixTargetID string + CameoIDs []string +} + +// SoraStoryboardRequest 分镜视频生成请求参数 +type SoraStoryboardRequest struct { + Prompt string + Orientation string + Frames int + Model string + Size string + MediaID string +} + +// SoraImageTaskStatus 图片任务状态 +type SoraImageTaskStatus struct { + ID string + Status string + ProgressPct float64 + URLs []string + ErrorMsg string +} + +// SoraVideoTaskStatus 视频任务状态 +type SoraVideoTaskStatus struct { + ID string + Status string + ProgressPct int + URLs []string + GenerationID string + ErrorMsg string +} + +// SoraCameoStatus 角色处理中间态 +type SoraCameoStatus struct { + Status string + StatusMessage string + DisplayNameHint string + UsernameHint string + ProfileAssetURL string + InstructionSetHint any + InstructionSet any +} + +// SoraCharacterFinalizeRequest 角色定稿请求参数 +type SoraCharacterFinalizeRequest struct { + CameoID string + Username string + DisplayName string + ProfileAssetPointer string + InstructionSet any +} + +// SoraUpstreamError 上游错误 +type 
SoraUpstreamError struct { + StatusCode int + Message string + Headers http.Header + Body []byte +} + +func (e *SoraUpstreamError) Error() string { + if e == nil { + return "sora upstream error" + } + if e.Message != "" { + return fmt.Sprintf("sora upstream error: %d %s", e.StatusCode, e.Message) + } + return fmt.Sprintf("sora upstream error: %d", e.StatusCode) +} diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go new file mode 100644 index 00000000..ab6871bb --- /dev/null +++ b/backend/internal/service/sora_gateway_service.go @@ -0,0 +1,1553 @@ +package service + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "math" + "math/rand" + "mime" + "net" + "net/http" + "net/url" + "regexp" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/gin-gonic/gin" +) + +const soraImageInputMaxBytes = 20 << 20 +const soraImageInputMaxRedirects = 3 +const soraImageInputTimeout = 20 * time.Second +const soraVideoInputMaxBytes = 200 << 20 +const soraVideoInputMaxRedirects = 3 +const soraVideoInputTimeout = 60 * time.Second + +var soraImageSizeMap = map[string]string{ + "gpt-image": "360", + "gpt-image-landscape": "540", + "gpt-image-portrait": "540", +} + +var soraBlockedHostnames = map[string]struct{}{ + "localhost": {}, + "localhost.localdomain": {}, + "metadata.google.internal": {}, + "metadata.google.internal.": {}, +} + +var soraBlockedCIDRs = mustParseCIDRs([]string{ + "0.0.0.0/8", + "10.0.0.0/8", + "100.64.0.0/10", + "127.0.0.0/8", + "169.254.0.0/16", + "172.16.0.0/12", + "192.168.0.0/16", + "224.0.0.0/4", + "240.0.0.0/4", + "::/128", + "::1/128", + "fc00::/7", + "fe80::/10", +}) + +// SoraGatewayService handles forwarding requests to Sora upstream. 
+type SoraGatewayService struct { + soraClient SoraClient + rateLimitService *RateLimitService + httpUpstream HTTPUpstream // 用于 apikey 类型账号的 HTTP 透传 + cfg *config.Config +} + +type soraWatermarkOptions struct { + Enabled bool + ParseMethod string + ParseURL string + ParseToken string + FallbackOnFailure bool + DeletePost bool +} + +type soraCharacterOptions struct { + SetPublic bool + DeleteAfterGenerate bool +} + +type soraCharacterFlowResult struct { + CameoID string + CharacterID string + Username string + DisplayName string +} + +var soraStoryboardPattern = regexp.MustCompile(`\[\d+(?:\.\d+)?s\]`) +var soraStoryboardShotPattern = regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`) +var soraRemixTargetPattern = regexp.MustCompile(`s_[a-f0-9]{32}`) +var soraRemixTargetInURLPattern = regexp.MustCompile(`https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}`) + +type soraPreflightChecker interface { + PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error +} + +func NewSoraGatewayService( + soraClient SoraClient, + rateLimitService *RateLimitService, + httpUpstream HTTPUpstream, + cfg *config.Config, +) *SoraGatewayService { + return &SoraGatewayService{ + soraClient: soraClient, + rateLimitService: rateLimitService, + httpUpstream: httpUpstream, + cfg: cfg, + } +} + +func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) { + startTime := time.Now() + + // apikey 类型账号:HTTP 透传到上游,不走 SoraSDKClient + if account.Type == AccountTypeAPIKey && account.GetBaseURL() != "" { + if s.httpUpstream == nil { + s.writeSoraError(c, http.StatusInternalServerError, "api_error", "HTTP upstream client not configured", clientStream) + return nil, errors.New("httpUpstream not configured for sora apikey forwarding") + } + return s.forwardToUpstream(ctx, c, account, body, clientStream, startTime) + } + + if s.soraClient == nil || 
!s.soraClient.Enabled() { + if c != nil { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": gin.H{ + "type": "api_error", + "message": "Sora 上游未配置", + }, + }) + } + return nil, errors.New("sora upstream not configured") + } + + var reqBody map[string]any + if err := json.Unmarshal(body, &reqBody); err != nil { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body", clientStream) + return nil, fmt.Errorf("parse request: %w", err) + } + reqModel, _ := reqBody["model"].(string) + reqStream, _ := reqBody["stream"].(bool) + if strings.TrimSpace(reqModel) == "" { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "model is required", clientStream) + return nil, errors.New("model is required") + } + + mappedModel := account.GetMappedModel(reqModel) + if mappedModel != "" && mappedModel != reqModel { + reqModel = mappedModel + } + + modelCfg, ok := GetSoraModelConfig(reqModel) + if !ok { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream) + return nil, fmt.Errorf("unsupported model: %s", reqModel) + } + prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody) + prompt = strings.TrimSpace(prompt) + imageInput = strings.TrimSpace(imageInput) + videoInput = strings.TrimSpace(videoInput) + remixTargetID = strings.TrimSpace(remixTargetID) + + if videoInput != "" && modelCfg.Type != "video" { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "video input only supports video models", clientStream) + return nil, errors.New("video input only supports video models") + } + if videoInput != "" && imageInput != "" { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "image input and video input cannot be used together", clientStream) + return nil, errors.New("image input and video input cannot be used together") + } + characterOnly := videoInput != "" && prompt == "" + if modelCfg.Type == 
"prompt_enhance" && prompt == "" { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream) + return nil, errors.New("prompt is required") + } + if modelCfg.Type != "prompt_enhance" && prompt == "" && !characterOnly { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream) + return nil, errors.New("prompt is required") + } + + reqCtx, cancel := s.withSoraTimeout(ctx, reqStream) + if cancel != nil { + defer cancel() + } + if checker, ok := s.soraClient.(soraPreflightChecker); ok && !characterOnly { + if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil { + return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) + } + } + + if modelCfg.Type == "prompt_enhance" { + enhancedPrompt, err := s.soraClient.EnhancePrompt(reqCtx, account, prompt, modelCfg.ExpansionLevel, modelCfg.DurationS) + if err != nil { + return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) + } + content := strings.TrimSpace(enhancedPrompt) + if content == "" { + content = prompt + } + var firstTokenMs *int + if clientStream { + ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) + if streamErr != nil { + return nil, streamErr + } + firstTokenMs = ms + } else if c != nil { + c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel)) + } + return &ForwardResult{ + RequestID: "", + Model: reqModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{}, + MediaType: "prompt", + }, nil + } + + characterOpts := parseSoraCharacterOptions(reqBody) + watermarkOpts := parseSoraWatermarkOptions(reqBody) + var characterResult *soraCharacterFlowResult + if videoInput != "" { + videoData, videoErr := decodeSoraVideoInput(reqCtx, videoInput) + if videoErr != nil { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", videoErr.Error(), 
clientStream) + return nil, videoErr + } + characterResult, videoErr = s.createCharacterFromVideo(reqCtx, account, videoData, characterOpts) + if videoErr != nil { + return nil, s.handleSoraRequestError(ctx, account, videoErr, reqModel, c, clientStream) + } + if characterResult != nil && characterOpts.DeleteAfterGenerate && strings.TrimSpace(characterResult.CharacterID) != "" && !characterOnly { + characterID := strings.TrimSpace(characterResult.CharacterID) + defer func() { + cleanupCtx, cancelCleanup := context.WithTimeout(context.Background(), 15*time.Second) + defer cancelCleanup() + if err := s.soraClient.DeleteCharacter(cleanupCtx, account, characterID); err != nil { + log.Printf("[Sora] cleanup character failed, character_id=%s err=%v", characterID, err) + } + }() + } + if characterOnly { + content := "角色创建成功" + if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" { + content = fmt.Sprintf("角色创建成功,角色名@%s", strings.TrimSpace(characterResult.Username)) + } + var firstTokenMs *int + if clientStream { + ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) + if streamErr != nil { + return nil, streamErr + } + firstTokenMs = ms + } else if c != nil { + resp := buildSoraNonStreamResponse(content, reqModel) + if characterResult != nil { + resp["character_id"] = characterResult.CharacterID + resp["cameo_id"] = characterResult.CameoID + resp["character_username"] = characterResult.Username + resp["character_display_name"] = characterResult.DisplayName + } + c.JSON(http.StatusOK, resp) + } + return &ForwardResult{ + RequestID: "", + Model: reqModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{}, + MediaType: "prompt", + }, nil + } + if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" { + prompt = fmt.Sprintf("@%s %s", characterResult.Username, prompt) + } + } + + var imageData []byte + imageFilename := "" + if imageInput != "" { + 
decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput) + if err != nil { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream) + return nil, err + } + imageData = decoded + imageFilename = filename + } + + mediaID := "" + if len(imageData) > 0 { + uploadID, err := s.soraClient.UploadImage(reqCtx, account, imageData, imageFilename) + if err != nil { + return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) + } + mediaID = uploadID + } + + taskID := "" + var err error + videoCount := parseSoraVideoCount(reqBody) + switch modelCfg.Type { + case "image": + taskID, err = s.soraClient.CreateImageTask(reqCtx, account, SoraImageRequest{ + Prompt: prompt, + Width: modelCfg.Width, + Height: modelCfg.Height, + MediaID: mediaID, + }) + case "video": + if remixTargetID == "" && isSoraStoryboardPrompt(prompt) { + taskID, err = s.soraClient.CreateStoryboardTask(reqCtx, account, SoraStoryboardRequest{ + Prompt: formatSoraStoryboardPrompt(prompt), + Orientation: modelCfg.Orientation, + Frames: modelCfg.Frames, + Model: modelCfg.Model, + Size: modelCfg.Size, + MediaID: mediaID, + }) + } else { + taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{ + Prompt: prompt, + Orientation: modelCfg.Orientation, + Frames: modelCfg.Frames, + Model: modelCfg.Model, + Size: modelCfg.Size, + VideoCount: videoCount, + MediaID: mediaID, + RemixTargetID: remixTargetID, + CameoIDs: extractSoraCameoIDs(reqBody), + }) + } + default: + err = fmt.Errorf("unsupported model type: %s", modelCfg.Type) + } + if err != nil { + return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) + } + + if clientStream && c != nil { + s.prepareSoraStream(c, taskID) + } + + var mediaURLs []string + videoGenerationID := "" + mediaType := modelCfg.Type + imageCount := 0 + imageSize := "" + switch modelCfg.Type { + case "image": + urls, pollErr := s.pollImageTask(reqCtx, c, account, taskID, 
clientStream) + if pollErr != nil { + return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream) + } + mediaURLs = urls + imageCount = len(urls) + imageSize = soraImageSizeFromModel(reqModel) + case "video": + videoStatus, pollErr := s.pollVideoTaskDetailed(reqCtx, c, account, taskID, clientStream) + if pollErr != nil { + return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream) + } + if videoStatus != nil { + mediaURLs = videoStatus.URLs + videoGenerationID = strings.TrimSpace(videoStatus.GenerationID) + } + default: + mediaType = "prompt" + } + + watermarkPostID := "" + if modelCfg.Type == "video" && watermarkOpts.Enabled { + watermarkURL, postID, watermarkErr := s.resolveWatermarkFreeURL(reqCtx, account, videoGenerationID, watermarkOpts) + if watermarkErr != nil { + if !watermarkOpts.FallbackOnFailure { + return nil, s.handleSoraRequestError(ctx, account, watermarkErr, reqModel, c, clientStream) + } + log.Printf("[Sora] watermark-free fallback to original URL, task_id=%s err=%v", taskID, watermarkErr) + } else if strings.TrimSpace(watermarkURL) != "" { + mediaURLs = []string{strings.TrimSpace(watermarkURL)} + watermarkPostID = strings.TrimSpace(postID) + } + } + + // 直调路径(/sora/v1/chat/completions)保持纯透传,不执行本地/S3 媒体落盘。 + // 媒体存储由客户端 API 路径(/api/v1/sora/generate)的异步流程负责。 + finalURLs := s.normalizeSoraMediaURLs(mediaURLs) + if watermarkPostID != "" && watermarkOpts.DeletePost { + if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil { + log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr) + } + } + + content := buildSoraContent(mediaType, finalURLs) + var firstTokenMs *int + if clientStream { + ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) + if streamErr != nil { + return nil, streamErr + } + firstTokenMs = ms + } else if c != nil { + response := buildSoraNonStreamResponse(content, reqModel) + if len(finalURLs) > 0 { 
+ response["media_url"] = finalURLs[0] + if len(finalURLs) > 1 { + response["media_urls"] = finalURLs + } + } + c.JSON(http.StatusOK, response) + } + + return &ForwardResult{ + RequestID: taskID, + Model: reqModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{}, + MediaType: mediaType, + MediaURL: firstMediaURL(finalURLs), + ImageCount: imageCount, + ImageSize: imageSize, + }, nil +} + +func (s *SoraGatewayService) withSoraTimeout(ctx context.Context, stream bool) (context.Context, context.CancelFunc) { + if s == nil || s.cfg == nil { + return ctx, nil + } + timeoutSeconds := s.cfg.Gateway.SoraRequestTimeoutSeconds + if stream { + timeoutSeconds = s.cfg.Gateway.SoraStreamTimeoutSeconds + } + if timeoutSeconds <= 0 { + return ctx, nil + } + return context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second) +} + +func parseSoraWatermarkOptions(body map[string]any) soraWatermarkOptions { + opts := soraWatermarkOptions{ + Enabled: parseBoolWithDefault(body, "watermark_free", false), + ParseMethod: strings.ToLower(strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_method", "third_party"))), + ParseURL: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_url", "")), + ParseToken: strings.TrimSpace(parseStringWithDefault(body, "watermark_parse_token", "")), + FallbackOnFailure: parseBoolWithDefault(body, "watermark_fallback_on_failure", true), + DeletePost: parseBoolWithDefault(body, "watermark_delete_post", false), + } + if opts.ParseMethod == "" { + opts.ParseMethod = "third_party" + } + return opts +} + +func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions { + return soraCharacterOptions{ + SetPublic: parseBoolWithDefault(body, "character_set_public", true), + DeleteAfterGenerate: parseBoolWithDefault(body, "character_delete_after_generate", true), + } +} + +func parseSoraVideoCount(body map[string]any) int { + if body == nil { + return 1 + } + keys := 
[]string{"video_count", "videos", "n_variants"} + for _, key := range keys { + count := parseIntWithDefault(body, key, 0) + if count > 0 { + return clampInt(count, 1, 3) + } + } + return 1 +} + +func parseBoolWithDefault(body map[string]any, key string, def bool) bool { + if body == nil { + return def + } + val, ok := body[key] + if !ok { + return def + } + switch typed := val.(type) { + case bool: + return typed + case int: + return typed != 0 + case int32: + return typed != 0 + case int64: + return typed != 0 + case float64: + return typed != 0 + case string: + typed = strings.ToLower(strings.TrimSpace(typed)) + if typed == "true" || typed == "1" || typed == "yes" { + return true + } + if typed == "false" || typed == "0" || typed == "no" { + return false + } + } + return def +} + +func parseStringWithDefault(body map[string]any, key, def string) string { + if body == nil { + return def + } + val, ok := body[key] + if !ok { + return def + } + if str, ok := val.(string); ok { + return str + } + return def +} + +func parseIntWithDefault(body map[string]any, key string, def int) int { + if body == nil { + return def + } + val, ok := body[key] + if !ok { + return def + } + switch typed := val.(type) { + case int: + return typed + case int32: + return int(typed) + case int64: + return int(typed) + case float64: + return int(typed) + case string: + parsed, err := strconv.Atoi(strings.TrimSpace(typed)) + if err == nil { + return parsed + } + } + return def +} + +func clampInt(v, minVal, maxVal int) int { + if v < minVal { + return minVal + } + if v > maxVal { + return maxVal + } + return v +} + +func extractSoraCameoIDs(body map[string]any) []string { + if body == nil { + return nil + } + raw, ok := body["cameo_ids"] + if !ok { + return nil + } + switch typed := raw.(type) { + case []string: + out := make([]string, 0, len(typed)) + for _, item := range typed { + item = strings.TrimSpace(item) + if item != "" { + out = append(out, item) + } + } + return out + case []any: 
+ out := make([]string, 0, len(typed)) + for _, item := range typed { + str, ok := item.(string) + if !ok { + continue + } + str = strings.TrimSpace(str) + if str != "" { + out = append(out, str) + } + } + return out + default: + return nil + } +} + +func (s *SoraGatewayService) createCharacterFromVideo(ctx context.Context, account *Account, videoData []byte, opts soraCharacterOptions) (*soraCharacterFlowResult, error) { + cameoID, err := s.soraClient.UploadCharacterVideo(ctx, account, videoData) + if err != nil { + return nil, err + } + + cameoStatus, err := s.pollCameoStatus(ctx, account, cameoID) + if err != nil { + return nil, err + } + username := processSoraCharacterUsername(cameoStatus.UsernameHint) + displayName := strings.TrimSpace(cameoStatus.DisplayNameHint) + if displayName == "" { + displayName = "Character" + } + profileAssetURL := strings.TrimSpace(cameoStatus.ProfileAssetURL) + if profileAssetURL == "" { + return nil, errors.New("profile asset url not found in cameo status") + } + + avatarData, err := s.soraClient.DownloadCharacterImage(ctx, account, profileAssetURL) + if err != nil { + return nil, err + } + assetPointer, err := s.soraClient.UploadCharacterImage(ctx, account, avatarData) + if err != nil { + return nil, err + } + instructionSet := cameoStatus.InstructionSetHint + if instructionSet == nil { + instructionSet = cameoStatus.InstructionSet + } + + characterID, err := s.soraClient.FinalizeCharacter(ctx, account, SoraCharacterFinalizeRequest{ + CameoID: strings.TrimSpace(cameoID), + Username: username, + DisplayName: displayName, + ProfileAssetPointer: assetPointer, + InstructionSet: instructionSet, + }) + if err != nil { + return nil, err + } + + if opts.SetPublic { + if err := s.soraClient.SetCharacterPublic(ctx, account, cameoID); err != nil { + return nil, err + } + } + + return &soraCharacterFlowResult{ + CameoID: strings.TrimSpace(cameoID), + CharacterID: strings.TrimSpace(characterID), + Username: strings.TrimSpace(username), + 
DisplayName: displayName,
	}, nil
}

// pollCameoStatus polls GetCameoStatus until the cameo finishes processing.
// It derives an attempt budget from a fixed 10-minute timeout at 5-second
// intervals, aborts early after 3 consecutive transport errors, treats a
// "failed" status as terminal (surfacing the upstream status message when
// present), and returns the status once the message equals "Completed" or
// the status is "finalized".
func (s *SoraGatewayService) pollCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) {
	timeout := 10 * time.Minute
	interval := 5 * time.Second
	// Attempt count = timeout / interval, guarded against zero.
	maxAttempts := int(math.Ceil(timeout.Seconds() / interval.Seconds()))
	if maxAttempts < 1 {
		maxAttempts = 1
	}

	var lastErr error
	consecutiveErrors := 0
	for attempt := 0; attempt < maxAttempts; attempt++ {
		status, err := s.soraClient.GetCameoStatus(ctx, account, cameoID)
		if err != nil {
			lastErr = err
			consecutiveErrors++
			// Three errors in a row: assume the upstream is unhealthy and stop.
			if consecutiveErrors >= 3 {
				break
			}
			// Skip the sleep on the final attempt; sleepWithContext aborts on
			// context cancellation.
			if attempt < maxAttempts-1 {
				if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
					return nil, sleepErr
				}
			}
			continue
		}
		consecutiveErrors = 0
		if status == nil {
			// No status payload yet; wait and retry.
			if attempt < maxAttempts-1 {
				if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
					return nil, sleepErr
				}
			}
			continue
		}
		currentStatus := strings.ToLower(strings.TrimSpace(status.Status))
		statusMessage := strings.TrimSpace(status.StatusMessage)
		if currentStatus == "failed" {
			if statusMessage == "" {
				statusMessage = "character creation failed"
			}
			return nil, errors.New(statusMessage)
		}
		// Success is signalled either by the message "Completed" (any case)
		// or by the status "finalized".
		if strings.EqualFold(statusMessage, "Completed") || currentStatus == "finalized" {
			return status, nil
		}
		if attempt < maxAttempts-1 {
			if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil {
				return nil, sleepErr
			}
		}
	}
	// Prefer reporting the last transport error over a generic timeout.
	if lastErr != nil {
		return nil, fmt.Errorf("poll cameo status failed: %w", lastErr)
	}
	return nil, errors.New("cameo processing timeout")
}

// processSoraCharacterUsername normalizes the upstream username hint: it keeps
// only the segment after the last ".", falls back to "character" when the
// hint (or that segment) is empty, and appends a random 3-digit suffix
// (100-999) to reduce collisions.
func processSoraCharacterUsername(usernameHint string) string {
	usernameHint = strings.TrimSpace(usernameHint)
	if usernameHint == "" {
		usernameHint = "character"
	}
	if strings.Contains(usernameHint, ".") {
		parts := strings.Split(usernameHint, ".")
		usernameHint = strings.TrimSpace(parts[len(parts)-1])
	}
	if usernameHint == "" {
		usernameHint = "character"
+ } + return fmt.Sprintf("%s%d", usernameHint, rand.Intn(900)+100) +} + +func (s *SoraGatewayService) resolveWatermarkFreeURL(ctx context.Context, account *Account, generationID string, opts soraWatermarkOptions) (string, string, error) { + generationID = strings.TrimSpace(generationID) + if generationID == "" { + return "", "", errors.New("generation id is required for watermark-free mode") + } + postID, err := s.soraClient.PostVideoForWatermarkFree(ctx, account, generationID) + if err != nil { + return "", "", err + } + postID = strings.TrimSpace(postID) + if postID == "" { + return "", "", errors.New("watermark-free publish returned empty post id") + } + + switch opts.ParseMethod { + case "custom": + urlVal, parseErr := s.soraClient.GetWatermarkFreeURLCustom(ctx, account, opts.ParseURL, opts.ParseToken, postID) + if parseErr != nil { + return "", postID, parseErr + } + return strings.TrimSpace(urlVal), postID, nil + case "", "third_party": + return fmt.Sprintf("https://oscdn2.dyysy.com/MP4/%s.mp4", postID), postID, nil + default: + return "", postID, fmt.Errorf("unsupported watermark parse method: %s", opts.ParseMethod) + } +} + +func (s *SoraGatewayService) shouldFailoverUpstreamError(statusCode int) bool { + switch statusCode { + case 401, 402, 403, 404, 429, 529: + return true + default: + return statusCode >= 500 + } +} + +func buildSoraNonStreamResponse(content, model string) map[string]any { + return map[string]any{ + "id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()), + "object": "chat.completion", + "created": time.Now().Unix(), + "model": model, + "choices": []any{ + map[string]any{ + "index": 0, + "message": map[string]any{ + "role": "assistant", + "content": content, + }, + "finish_reason": "stop", + }, + }, + } +} + +func soraImageSizeFromModel(model string) string { + modelLower := strings.ToLower(model) + if size, ok := soraImageSizeMap[modelLower]; ok { + return size + } + if strings.Contains(modelLower, "landscape") || 
strings.Contains(modelLower, "portrait") { + return "540" + } + return "360" +} + +func soraProErrorMessage(model, upstreamMsg string) string { + modelLower := strings.ToLower(model) + if strings.Contains(modelLower, "sora2pro-hd") { + return "当前账号无法使用 Sora Pro-HD 模型,请更换模型或账号" + } + if strings.Contains(modelLower, "sora2pro") { + return "当前账号无法使用 Sora Pro 模型,请更换模型或账号" + } + return "" +} + +func firstMediaURL(urls []string) string { + if len(urls) == 0 { + return "" + } + return urls[0] +} + +func (s *SoraGatewayService) buildSoraMediaURL(path string, rawQuery string) string { + if path == "" { + return path + } + prefix := "/sora/media" + values := url.Values{} + if rawQuery != "" { + if parsed, err := url.ParseQuery(rawQuery); err == nil { + values = parsed + } + } + + signKey := "" + ttlSeconds := 0 + if s != nil && s.cfg != nil { + signKey = strings.TrimSpace(s.cfg.Gateway.SoraMediaSigningKey) + ttlSeconds = s.cfg.Gateway.SoraMediaSignedURLTTLSeconds + } + values.Del("sig") + values.Del("expires") + signingQuery := values.Encode() + if signKey != "" && ttlSeconds > 0 { + expires := time.Now().Add(time.Duration(ttlSeconds) * time.Second).Unix() + signature := SignSoraMediaURL(path, signingQuery, expires, signKey) + if signature != "" { + values.Set("expires", strconv.FormatInt(expires, 10)) + values.Set("sig", signature) + prefix = "/sora/media-signed" + } + } + + encoded := values.Encode() + if encoded == "" { + return prefix + path + } + return prefix + path + "?" 
+ encoded +} + +func (s *SoraGatewayService) prepareSoraStream(c *gin.Context, requestID string) { + if c == nil { + return + } + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + if strings.TrimSpace(requestID) != "" { + c.Header("x-request-id", requestID) + } +} + +func (s *SoraGatewayService) writeSoraStream(c *gin.Context, model, content string, startTime time.Time) (*int, error) { + if c == nil { + return nil, nil + } + writer := c.Writer + flusher, _ := writer.(http.Flusher) + + chunk := map[string]any{ + "id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()), + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []any{ + map[string]any{ + "index": 0, + "delta": map[string]any{ + "content": content, + }, + }, + }, + } + encoded, _ := jsonMarshalRaw(chunk) + if _, err := fmt.Fprintf(writer, "data: %s\n\n", encoded); err != nil { + return nil, err + } + if flusher != nil { + flusher.Flush() + } + ms := int(time.Since(startTime).Milliseconds()) + finalChunk := map[string]any{ + "id": chunk["id"], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []any{ + map[string]any{ + "index": 0, + "delta": map[string]any{}, + "finish_reason": "stop", + }, + }, + } + finalEncoded, _ := jsonMarshalRaw(finalChunk) + if _, err := fmt.Fprintf(writer, "data: %s\n\n", finalEncoded); err != nil { + return &ms, err + } + if _, err := fmt.Fprint(writer, "data: [DONE]\n\n"); err != nil { + return &ms, err + } + if flusher != nil { + flusher.Flush() + } + return &ms, nil +} + +func (s *SoraGatewayService) writeSoraError(c *gin.Context, status int, errType, message string, stream bool) { + if c == nil { + return + } + if stream { + flusher, _ := c.Writer.(http.Flusher) + errorData := map[string]any{ + "error": map[string]string{ + "type": errType, + "message": message, + }, + } 
+ jsonBytes, err := json.Marshal(errorData) + if err != nil { + _ = c.Error(err) + return + } + errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes)) + _, _ = fmt.Fprint(c.Writer, errorEvent) + _, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n") + if flusher != nil { + flusher.Flush() + } + return + } + c.JSON(status, gin.H{ + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} + +func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account *Account, err error, model string, c *gin.Context, stream bool) error { + if err == nil { + return nil + } + var upstreamErr *SoraUpstreamError + if errors.As(err, &upstreamErr) { + accountID := int64(0) + if account != nil { + accountID = account.ID + } + logger.LegacyPrintf( + "service.sora", + "[SoraRawError] account_id=%d model=%s status=%d request_id=%s cf_ray=%s message=%s raw_body=%s", + accountID, + model, + upstreamErr.StatusCode, + strings.TrimSpace(upstreamErr.Headers.Get("x-request-id")), + strings.TrimSpace(upstreamErr.Headers.Get("cf-ray")), + strings.TrimSpace(upstreamErr.Message), + truncateForLog(upstreamErr.Body, 1024), + ) + if s.rateLimitService != nil && account != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body) + } + if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) { + var responseHeaders http.Header + if upstreamErr.Headers != nil { + responseHeaders = upstreamErr.Headers.Clone() + } + return &UpstreamFailoverError{ + StatusCode: upstreamErr.StatusCode, + ResponseBody: upstreamErr.Body, + ResponseHeaders: responseHeaders, + } + } + msg := upstreamErr.Message + if override := soraProErrorMessage(model, msg); override != "" { + msg = override + } + s.writeSoraError(c, upstreamErr.StatusCode, "upstream_error", msg, stream) + return err + } + if errors.Is(err, context.DeadlineExceeded) { + s.writeSoraError(c, http.StatusGatewayTimeout, "timeout_error", "Sora generation 
timeout", stream) + return err + } + s.writeSoraError(c, http.StatusBadGateway, "api_error", err.Error(), stream) + return err +} + +func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) { + interval := s.pollInterval() + maxAttempts := s.pollMaxAttempts() + lastPing := time.Now() + for attempt := 0; attempt < maxAttempts; attempt++ { + status, err := s.soraClient.GetImageTask(ctx, account, taskID) + if err != nil { + return nil, err + } + switch strings.ToLower(status.Status) { + case "succeeded", "completed": + return status.URLs, nil + case "failed": + if status.ErrorMsg != "" { + return nil, errors.New(status.ErrorMsg) + } + return nil, errors.New("sora image generation failed") + } + if stream { + s.maybeSendPing(c, &lastPing) + } + if err := sleepWithContext(ctx, interval); err != nil { + return nil, err + } + } + return nil, errors.New("sora image generation timeout") +} + +func (s *SoraGatewayService) pollVideoTaskDetailed(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) (*SoraVideoTaskStatus, error) { + interval := s.pollInterval() + maxAttempts := s.pollMaxAttempts() + lastPing := time.Now() + for attempt := 0; attempt < maxAttempts; attempt++ { + status, err := s.soraClient.GetVideoTask(ctx, account, taskID) + if err != nil { + return nil, err + } + switch strings.ToLower(status.Status) { + case "completed", "succeeded": + return status, nil + case "failed": + if status.ErrorMsg != "" { + return nil, errors.New(status.ErrorMsg) + } + return nil, errors.New("sora video generation failed") + } + if stream { + s.maybeSendPing(c, &lastPing) + } + if err := sleepWithContext(ctx, interval); err != nil { + return nil, err + } + } + return nil, errors.New("sora video generation timeout") +} + +func (s *SoraGatewayService) pollInterval() time.Duration { + if s == nil || s.cfg == nil { + return 2 * time.Second + } + interval := 
s.cfg.Sora.Client.PollIntervalSeconds
	if interval <= 0 {
		interval = 2 // default: poll every 2 seconds
	}
	return time.Duration(interval) * time.Second
}

// pollMaxAttempts returns the configured maximum number of task polls,
// defaulting to 600 when the service/config is absent or non-positive.
func (s *SoraGatewayService) pollMaxAttempts() int {
	if s == nil || s.cfg == nil {
		return 600
	}
	maxAttempts := s.cfg.Sora.Client.MaxPollAttempts
	if maxAttempts <= 0 {
		maxAttempts = 600
	}
	return maxAttempts
}

// maybeSendPing writes an SSE comment line (":") as a keep-alive if at least
// the configured ping interval (Concurrency.PingInterval seconds, default
// 10s) has elapsed since *lastPing. On a successful write it flushes the
// stream and resets *lastPing.
func (s *SoraGatewayService) maybeSendPing(c *gin.Context, lastPing *time.Time) {
	if c == nil {
		return
	}
	interval := 10 * time.Second
	if s != nil && s.cfg != nil && s.cfg.Concurrency.PingInterval > 0 {
		interval = time.Duration(s.cfg.Concurrency.PingInterval) * time.Second
	}
	if time.Since(*lastPing) < interval {
		return
	}
	// Only advance lastPing when the write actually succeeded.
	if _, err := fmt.Fprint(c.Writer, ":\n\n"); err == nil {
		if flusher, ok := c.Writer.(http.Flusher); ok {
			flusher.Flush()
		}
		*lastPing = time.Now()
	}
}

// normalizeSoraMediaURLs converts relative upstream media paths into local
// /sora/media proxy URLs (via buildSoraMediaURL) while passing absolute
// http(s) URLs through unchanged. Blank entries are dropped.
func (s *SoraGatewayService) normalizeSoraMediaURLs(urls []string) []string {
	if len(urls) == 0 {
		return urls
	}
	output := make([]string, 0, len(urls))
	for _, raw := range urls {
		raw = strings.TrimSpace(raw)
		if raw == "" {
			continue
		}
		if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") {
			output = append(output, raw)
			continue
		}
		pathVal := raw
		if !strings.HasPrefix(pathVal, "/") {
			pathVal = "/" + pathVal
		}
		output = append(output, s.buildSoraMediaURL(pathVal, ""))
	}
	return output
}

// jsonMarshalRaw serializes v to JSON without escaping HTML characters
// (&, <, >), so that "&" inside URLs is not turned into \u0026, which would
// make the URLs unusable for clients as-is.
func jsonMarshalRaw(v any) ([]byte, error) {
	var buf bytes.Buffer
	enc := json.NewEncoder(&buf)
	enc.SetEscapeHTML(false)
	if err := enc.Encode(v); err != nil {
		return nil, err
	}
	// Encoder.Encode appends a trailing newline; strip it.
	b := buf.Bytes()
	if len(b) > 0 && b[len(b)-1] == '\n' {
		b = b[:len(b)-1]
	}
	return b, nil
}

// buildSoraContent renders generated media URLs as chat-message content:
// images as markdown image links (one per line), videos as an HTML snippet
// inside a fenced code block. Unknown media types yield an empty string.
func buildSoraContent(mediaType string, urls []string) string {
	switch mediaType {
	case "image":
		parts := make([]string, 0, len(urls))
		for _, u := range urls {
			parts = append(parts, fmt.Sprintf("![image](%s)", u))
		}
		return strings.Join(parts, "\n")
	case "video":
		if len(urls) == 0 {
			return ""
		}
		// FIX: the format string had no formatting verb, so urls[0] was never
		// interpolated and Sprintf appended a "%!(EXTRA string=...)" artifact.
		// NOTE(review): the original video tag appears lost; a minimal
		// <video> element is emitted here — confirm the intended markup.
		return fmt.Sprintf("```html\n<video controls src=%q></video>\n```", urls[0])
	default:
		return ""
	}
}

// extractSoraInput pulls the prompt, optional image/video inputs and an
// optional remix target id out of the request body. It accepts both flat
// fields (prompt/image/video/remix_target_id) and chat-style "messages"
// (user-role text / image_url / video_url parts). A remix target id found
// inside the prompt is extracted and then stripped from the prompt text.
func extractSoraInput(body map[string]any) (prompt, imageInput, videoInput, remixTargetID string) {
	if body == nil {
		return "", "", "", ""
	}
	if v, ok := body["remix_target_id"].(string); ok {
		remixTargetID = strings.TrimSpace(v)
	}
	if v, ok := body["image"].(string); ok {
		imageInput = v
	}
	if v, ok := body["video"].(string); ok {
		videoInput = v
	}
	if v, ok := body["prompt"].(string); ok && strings.TrimSpace(v) != "" {
		prompt = v
	}
	if messages, ok := body["messages"].([]any); ok {
		builder := strings.Builder{}
		for _, raw := range messages {
			msg, ok := raw.(map[string]any)
			if !ok {
				continue
			}
			role, _ := msg["role"].(string)
			// Only user messages (or messages with no role) contribute.
			if role != "" && role != "user" {
				continue
			}
			content := msg["content"]
			text, img, vid := parseSoraMessageContent(content)
			if text != "" {
				if builder.Len() > 0 {
					_, _ = builder.WriteString("\n")
				}
				_, _ = builder.WriteString(text)
			}
			// First media found wins; explicit top-level fields take precedence.
			if imageInput == "" && img != "" {
				imageInput = img
			}
			if videoInput == "" && vid != "" {
				videoInput = vid
			}
		}
		// The explicit "prompt" field takes precedence over message text.
		if prompt == "" {
			prompt = builder.String()
		}
	}
	if remixTargetID == "" {
		remixTargetID = extractRemixTargetIDFromPrompt(prompt)
	}
	prompt = cleanRemixLinkFromPrompt(prompt)
	return prompt, imageInput, videoInput, remixTargetID
}

// parseSoraMessageContent handles a single message "content" value, which may
// be a plain string or an OpenAI-style list of typed parts. Text parts are
// joined with newlines; only the first image_url/video_url part is kept.
func parseSoraMessageContent(content any) (text, imageInput, videoInput string) {
	switch val := content.(type) {
	case string:
		return val, "", ""
	case []any:
		builder := strings.Builder{}
		for _, item := range val {
			itemMap, ok := item.(map[string]any)
			if !ok {
				continue
			}
			t, _ := itemMap["type"].(string)
			switch t {
			case "text":
				if txt, ok := itemMap["text"].(string); ok && strings.TrimSpace(txt) != "" {
					if builder.Len() > 0 {
						_, _ =
builder.WriteString("\n") + } + _, _ = builder.WriteString(txt) + } + case "image_url": + if imageInput == "" { + if urlVal, ok := itemMap["image_url"].(map[string]any); ok { + imageInput = fmt.Sprintf("%v", urlVal["url"]) + } else if urlStr, ok := itemMap["image_url"].(string); ok { + imageInput = urlStr + } + } + case "video_url": + if videoInput == "" { + if urlVal, ok := itemMap["video_url"].(map[string]any); ok { + videoInput = fmt.Sprintf("%v", urlVal["url"]) + } else if urlStr, ok := itemMap["video_url"].(string); ok { + videoInput = urlStr + } + } + } + } + return builder.String(), imageInput, videoInput + default: + return "", "", "" + } +} + +func isSoraStoryboardPrompt(prompt string) bool { + prompt = strings.TrimSpace(prompt) + if prompt == "" { + return false + } + return len(soraStoryboardPattern.FindAllString(prompt, -1)) >= 1 +} + +func formatSoraStoryboardPrompt(prompt string) string { + prompt = strings.TrimSpace(prompt) + if prompt == "" { + return "" + } + matches := soraStoryboardShotPattern.FindAllStringSubmatch(prompt, -1) + if len(matches) == 0 { + return prompt + } + firstBracketPos := strings.Index(prompt, "[") + instructions := "" + if firstBracketPos > 0 { + instructions = strings.TrimSpace(prompt[:firstBracketPos]) + } + shots := make([]string, 0, len(matches)) + for i, match := range matches { + if len(match) < 3 { + continue + } + duration := strings.TrimSpace(match[1]) + scene := strings.TrimSpace(match[2]) + if scene == "" { + continue + } + shots = append(shots, fmt.Sprintf("Shot %d:\nduration: %ssec\nScene: %s", i+1, duration, scene)) + } + if len(shots) == 0 { + return prompt + } + timeline := strings.Join(shots, "\n\n") + if instructions == "" { + return timeline + } + return fmt.Sprintf("current timeline:\n%s\n\ninstructions:\n%s", timeline, instructions) +} + +func extractRemixTargetIDFromPrompt(prompt string) string { + prompt = strings.TrimSpace(prompt) + if prompt == "" { + return "" + } + return 
strings.TrimSpace(soraRemixTargetPattern.FindString(prompt)) +} + +func cleanRemixLinkFromPrompt(prompt string) string { + prompt = strings.TrimSpace(prompt) + if prompt == "" { + return prompt + } + cleaned := soraRemixTargetInURLPattern.ReplaceAllString(prompt, "") + cleaned = soraRemixTargetPattern.ReplaceAllString(cleaned, "") + cleaned = strings.Join(strings.Fields(cleaned), " ") + return strings.TrimSpace(cleaned) +} + +func decodeSoraImageInput(ctx context.Context, input string) ([]byte, string, error) { + raw := strings.TrimSpace(input) + if raw == "" { + return nil, "", errors.New("empty image input") + } + if strings.HasPrefix(raw, "data:") { + parts := strings.SplitN(raw, ",", 2) + if len(parts) != 2 { + return nil, "", errors.New("invalid data url") + } + meta := parts[0] + payload := parts[1] + decoded, err := decodeBase64WithLimit(payload, soraImageInputMaxBytes) + if err != nil { + return nil, "", err + } + ext := "" + if strings.HasPrefix(meta, "data:") { + metaParts := strings.SplitN(meta[5:], ";", 2) + if len(metaParts) > 0 { + if exts, err := mime.ExtensionsByType(metaParts[0]); err == nil && len(exts) > 0 { + ext = exts[0] + } + } + } + filename := "image" + ext + return decoded, filename, nil + } + if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") { + return downloadSoraImageInput(ctx, raw) + } + decoded, err := decodeBase64WithLimit(raw, soraImageInputMaxBytes) + if err != nil { + return nil, "", errors.New("invalid base64 image") + } + return decoded, "image.png", nil +} + +func decodeSoraVideoInput(ctx context.Context, input string) ([]byte, error) { + raw := strings.TrimSpace(input) + if raw == "" { + return nil, errors.New("empty video input") + } + if strings.HasPrefix(raw, "data:") { + parts := strings.SplitN(raw, ",", 2) + if len(parts) != 2 { + return nil, errors.New("invalid video data url") + } + decoded, err := decodeBase64WithLimit(parts[1], soraVideoInputMaxBytes) + if err != nil { + return nil, 
errors.New("invalid base64 video") + } + if len(decoded) == 0 { + return nil, errors.New("empty video data") + } + return decoded, nil + } + if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") { + return downloadSoraVideoInput(ctx, raw) + } + decoded, err := decodeBase64WithLimit(raw, soraVideoInputMaxBytes) + if err != nil { + return nil, errors.New("invalid base64 video") + } + if len(decoded) == 0 { + return nil, errors.New("empty video data") + } + return decoded, nil +} + +func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, error) { + parsed, err := validateSoraRemoteURL(rawURL) + if err != nil { + return nil, "", err + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil) + if err != nil { + return nil, "", err + } + client := &http.Client{ + Timeout: soraImageInputTimeout, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= soraImageInputMaxRedirects { + return errors.New("too many redirects") + } + return validateSoraRemoteURLValue(req.URL) + }, + } + resp, err := client.Do(req) + if err != nil { + return nil, "", err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, "", fmt.Errorf("download image failed: %d", resp.StatusCode) + } + data, err := io.ReadAll(io.LimitReader(resp.Body, soraImageInputMaxBytes)) + if err != nil { + return nil, "", err + } + ext := fileExtFromURL(parsed.String()) + if ext == "" { + ext = fileExtFromContentType(resp.Header.Get("Content-Type")) + } + filename := "image" + ext + return data, filename, nil +} + +func downloadSoraVideoInput(ctx context.Context, rawURL string) ([]byte, error) { + parsed, err := validateSoraRemoteURL(rawURL) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil) + if err != nil { + return nil, err + } + client := &http.Client{ + Timeout: soraVideoInputTimeout, + 
CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= soraVideoInputMaxRedirects { + return errors.New("too many redirects") + } + return validateSoraRemoteURLValue(req.URL) + }, + } + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("download video failed: %d", resp.StatusCode) + } + data, err := io.ReadAll(io.LimitReader(resp.Body, soraVideoInputMaxBytes)) + if err != nil { + return nil, err + } + if len(data) == 0 { + return nil, errors.New("empty video content") + } + return data, nil +} + +func decodeBase64WithLimit(encoded string, maxBytes int64) ([]byte, error) { + if maxBytes <= 0 { + return nil, errors.New("invalid max bytes limit") + } + decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)) + limited := io.LimitReader(decoder, maxBytes+1) + data, err := io.ReadAll(limited) + if err != nil { + return nil, err + } + if int64(len(data)) > maxBytes { + return nil, fmt.Errorf("input exceeds %d bytes limit", maxBytes) + } + return data, nil +} + +func validateSoraRemoteURL(raw string) (*url.URL, error) { + if strings.TrimSpace(raw) == "" { + return nil, errors.New("empty remote url") + } + parsed, err := url.Parse(raw) + if err != nil { + return nil, fmt.Errorf("invalid remote url: %w", err) + } + if err := validateSoraRemoteURLValue(parsed); err != nil { + return nil, err + } + return parsed, nil +} + +func validateSoraRemoteURLValue(parsed *url.URL) error { + if parsed == nil { + return errors.New("invalid remote url") + } + scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme)) + if scheme != "http" && scheme != "https" { + return errors.New("only http/https remote url is allowed") + } + if parsed.User != nil { + return errors.New("remote url cannot contain userinfo") + } + host := strings.ToLower(strings.TrimSpace(parsed.Hostname())) + if host == "" { + return 
errors.New("remote url missing host") + } + if _, blocked := soraBlockedHostnames[host]; blocked { + return errors.New("remote url is not allowed") + } + if ip := net.ParseIP(host); ip != nil { + if isSoraBlockedIP(ip) { + return errors.New("remote url is not allowed") + } + return nil + } + ips, err := net.LookupIP(host) + if err != nil { + return fmt.Errorf("resolve remote url failed: %w", err) + } + for _, ip := range ips { + if isSoraBlockedIP(ip) { + return errors.New("remote url is not allowed") + } + } + return nil +} + +func isSoraBlockedIP(ip net.IP) bool { + if ip == nil { + return true + } + for _, cidr := range soraBlockedCIDRs { + if cidr.Contains(ip) { + return true + } + } + return false +} + +func mustParseCIDRs(values []string) []*net.IPNet { + out := make([]*net.IPNet, 0, len(values)) + for _, val := range values { + _, cidr, err := net.ParseCIDR(val) + if err != nil { + continue + } + out = append(out, cidr) + } + return out +} diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go new file mode 100644 index 00000000..206636ff --- /dev/null +++ b/backend/internal/service/sora_gateway_service_test.go @@ -0,0 +1,558 @@ +//go:build unit + +package service + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +var _ SoraClient = (*stubSoraClientForPoll)(nil) + +type stubSoraClientForPoll struct { + imageStatus *SoraImageTaskStatus + videoStatus *SoraVideoTaskStatus + imageCalls int + videoCalls int + enhanced string + enhanceErr error + storyboard bool + videoReq SoraVideoRequest + parseErr error + postCalls int + deleteCalls int +} + +func (s *stubSoraClientForPoll) Enabled() bool { return true } +func (s *stubSoraClientForPoll) UploadImage(ctx context.Context, account *Account, data []byte, 
filename string) (string, error) { + return "", nil +} +func (s *stubSoraClientForPoll) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) { + return "task-image", nil +} +func (s *stubSoraClientForPoll) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) { + s.videoReq = req + return "task-video", nil +} +func (s *stubSoraClientForPoll) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) { + s.storyboard = true + return "task-video", nil +} +func (s *stubSoraClientForPoll) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) { + return "cameo-1", nil +} +func (s *stubSoraClientForPoll) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) { + return &SoraCameoStatus{ + Status: "finalized", + StatusMessage: "Completed", + DisplayNameHint: "Character", + UsernameHint: "user.character", + ProfileAssetURL: "https://example.com/avatar.webp", + }, nil +} +func (s *stubSoraClientForPoll) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) { + return []byte("avatar"), nil +} +func (s *stubSoraClientForPoll) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) { + return "asset-pointer", nil +} +func (s *stubSoraClientForPoll) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) { + return "character-1", nil +} +func (s *stubSoraClientForPoll) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error { + return nil +} +func (s *stubSoraClientForPoll) DeleteCharacter(ctx context.Context, account *Account, characterID string) error { + return nil +} +func (s *stubSoraClientForPoll) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) { + s.postCalls++ + return 
"s_post", nil +} +func (s *stubSoraClientForPoll) DeletePost(ctx context.Context, account *Account, postID string) error { + s.deleteCalls++ + return nil +} +func (s *stubSoraClientForPoll) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) { + if s.parseErr != nil { + return "", s.parseErr + } + return "https://example.com/no-watermark.mp4", nil +} +func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) { + if s.enhanced != "" { + return s.enhanced, s.enhanceErr + } + return "enhanced prompt", s.enhanceErr +} +func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) { + s.imageCalls++ + return s.imageStatus, nil +} +func (s *stubSoraClientForPoll) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) { + s.videoCalls++ + return s.videoStatus, nil +} + +func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) { + client := &stubSoraClientForPoll{ + imageStatus: &SoraImageTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/a.png"}, + }, + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + service := NewSoraGatewayService(client, nil, nil, cfg) + + urls, err := service.pollImageTask(context.Background(), nil, &Account{ID: 1}, "task", false) + require.NoError(t, err) + require.Equal(t, []string{"https://example.com/a.png"}, urls) + require.Equal(t, 1, client.imageCalls) +} + +func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) { + client := &stubSoraClientForPoll{ + enhanced: "cinematic prompt", + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := 
NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ + ID: 1, + Platform: PlatformSora, + Status: StatusActive, + } + body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "prompt", result.MediaType) + require.Equal(t, "prompt-enhance-short-10s", result.Model) +} + +func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) { + client := &stubSoraClientForPoll{ + videoStatus: &SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/v.mp4"}, + }, + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} + body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, client.storyboard) +} + +func TestSoraGatewayService_ForwardVideoCount(t *testing.T) { + client := &stubSoraClientForPoll{ + videoStatus: &SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/v.mp4"}, + }, + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} + body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"video_count":3,"stream":false}`) + + result, err := svc.Forward(context.Background(), nil, 
account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 3, client.videoReq.VideoCount) +} + +func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) { + client := &stubSoraClientForPoll{} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} + body := []byte(`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "prompt", result.MediaType) + require.Equal(t, 0, client.videoCalls) +} + +func TestSoraGatewayService_ForwardWatermarkFallback(t *testing.T) { + client := &stubSoraClientForPoll{ + videoStatus: &SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/original.mp4"}, + GenerationID: "gen_1", + }, + parseErr: errors.New("parse failed"), + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} + body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "https://example.com/original.mp4", result.MediaURL) + require.Equal(t, 1, client.postCalls) + require.Equal(t, 0, client.deleteCalls) +} + +func 
TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete(t *testing.T) { + client := &stubSoraClientForPoll{ + videoStatus: &SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/original.mp4"}, + GenerationID: "gen_1", + }, + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} + body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "https://example.com/no-watermark.mp4", result.MediaURL) + require.Equal(t, 1, client.postCalls) + require.Equal(t, 1, client.deleteCalls) +} + +func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) { + client := &stubSoraClientForPoll{ + videoStatus: &SoraVideoTaskStatus{ + Status: "failed", + ErrorMsg: "reject", + }, + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + service := NewSoraGatewayService(client, nil, nil, cfg) + + status, err := service.pollVideoTaskDetailed(context.Background(), nil, &Account{ID: 1}, "task", false) + require.Error(t, err) + require.Nil(t, status) + require.Contains(t, err.Error(), "reject") + require.Equal(t, 1, client.videoCalls) +} + +func TestSoraGatewayService_BuildSoraMediaURLSigned(t *testing.T) { + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + SoraMediaSigningKey: "test-key", + SoraMediaSignedURLTTLSeconds: 600, + }, + } + service := 
NewSoraGatewayService(nil, nil, nil, cfg) + + url := service.buildSoraMediaURL("/image/2025/01/01/a.png", "") + require.Contains(t, url, "/sora/media-signed") + require.Contains(t, url, "expires=") + require.Contains(t, url, "sig=") +} + +func TestNormalizeSoraMediaURLs_Empty(t *testing.T) { + svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) + result := svc.normalizeSoraMediaURLs(nil) + require.Empty(t, result) + + result = svc.normalizeSoraMediaURLs([]string{}) + require.Empty(t, result) +} + +func TestNormalizeSoraMediaURLs_HTTPUrls(t *testing.T) { + svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) + urls := []string{"https://example.com/a.png", "http://example.com/b.mp4"} + result := svc.normalizeSoraMediaURLs(urls) + require.Equal(t, urls, result) +} + +func TestNormalizeSoraMediaURLs_LocalPaths(t *testing.T) { + cfg := &config.Config{} + svc := NewSoraGatewayService(nil, nil, nil, cfg) + urls := []string{"/image/2025/01/a.png", "video/2025/01/b.mp4"} + result := svc.normalizeSoraMediaURLs(urls) + require.Len(t, result, 2) + require.Contains(t, result[0], "/sora/media") + require.Contains(t, result[1], "/sora/media") +} + +func TestNormalizeSoraMediaURLs_SkipsBlank(t *testing.T) { + svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) + urls := []string{"https://example.com/a.png", "", " ", "https://example.com/b.png"} + result := svc.normalizeSoraMediaURLs(urls) + require.Len(t, result, 2) +} + +func TestBuildSoraContent_Image(t *testing.T) { + content := buildSoraContent("image", []string{"https://a.com/1.png", "https://a.com/2.png"}) + require.Contains(t, content, "![image](https://a.com/1.png)") + require.Contains(t, content, "![image](https://a.com/2.png)") +} + +func TestBuildSoraContent_Video(t *testing.T) { + content := buildSoraContent("video", []string{"https://a.com/v.mp4"}) + require.Contains(t, content, "